e2e: import csi-test mock driver

This is a verbatim copy of the corresponding files in csi-test v4.0.2.

They'll be modified in future commits to make the code usable when
embedded in e2e.test. Some of those changes may be worthwhile
backporting to csi-test, but this is uncertain at this time.
This commit is contained in:
Patrick Ohly 2020-11-25 08:39:44 +01:00
parent 21ffdd1a28
commit 7f2b438020
11 changed files with 2786 additions and 0 deletions

View File

@ -0,0 +1,110 @@
/*
Copyright 2019 Kubernetes Authors
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/
package driver
import (
"context"
"net"
"sync"
"google.golang.org/grpc/reflection"
csi "github.com/container-storage-interface/spec/lib/go/csi"
"google.golang.org/grpc"
)
// CSIDriverControllerServer is the Controller service component of the driver.
type CSIDriverControllerServer struct {
Controller csi.ControllerServer
Identity csi.IdentityServer
}
// CSIDriverController is the CSI Driver Controller backend.
type CSIDriverController struct {
listener net.Listener
server *grpc.Server
controllerServer *CSIDriverControllerServer
wg sync.WaitGroup
running bool
lock sync.Mutex
creds *CSICreds
}
func NewCSIDriverController(controllerServer *CSIDriverControllerServer) *CSIDriverController {
return &CSIDriverController{
controllerServer: controllerServer,
}
}
func (c *CSIDriverController) goServe(started chan<- bool) {
goServe(c.server, &c.wg, c.listener, started)
}
func (c *CSIDriverController) Address() string {
return c.listener.Addr().String()
}
func (c *CSIDriverController) Start(l net.Listener) error {
c.lock.Lock()
defer c.lock.Unlock()
// Set listener.
c.listener = l
// Create a new grpc server.
c.server = grpc.NewServer(
grpc.UnaryInterceptor(c.callInterceptor),
)
if c.controllerServer.Controller != nil {
csi.RegisterControllerServer(c.server, c.controllerServer.Controller)
}
if c.controllerServer.Identity != nil {
csi.RegisterIdentityServer(c.server, c.controllerServer.Identity)
}
reflection.Register(c.server)
waitForServer := make(chan bool)
c.goServe(waitForServer)
<-waitForServer
c.running = true
return nil
}
func (c *CSIDriverController) Stop() {
stop(&c.lock, &c.wg, c.server, c.running)
}
func (c *CSIDriverController) Close() {
c.server.Stop()
}
func (c *CSIDriverController) IsRunning() bool {
c.lock.Lock()
defer c.lock.Unlock()
return c.running
}
func (c *CSIDriverController) SetDefaultCreds() {
setDefaultCreds(c.creds)
}
func (c *CSIDriverController) callInterceptor(ctx context.Context, req interface{}, info *grpc.UnaryServerInfo, handler grpc.UnaryHandler) (interface{}, error) {
return callInterceptor(ctx, c.creds, req, info, handler)
}

View File

@ -0,0 +1,109 @@
/*
Copyright 2019 Kubernetes Authors
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/
package driver
import (
context "context"
"net"
"sync"
csi "github.com/container-storage-interface/spec/lib/go/csi"
"google.golang.org/grpc"
"google.golang.org/grpc/reflection"
)
// CSIDriverNodeServer is the Node service component of the driver.
type CSIDriverNodeServer struct {
Node csi.NodeServer
Identity csi.IdentityServer
}
// CSIDriverNode is the CSI Driver Node backend.
type CSIDriverNode struct {
listener net.Listener
server *grpc.Server
nodeServer *CSIDriverNodeServer
wg sync.WaitGroup
running bool
lock sync.Mutex
creds *CSICreds
}
func NewCSIDriverNode(nodeServer *CSIDriverNodeServer) *CSIDriverNode {
return &CSIDriverNode{
nodeServer: nodeServer,
}
}
func (c *CSIDriverNode) goServe(started chan<- bool) {
goServe(c.server, &c.wg, c.listener, started)
}
func (c *CSIDriverNode) Address() string {
return c.listener.Addr().String()
}
func (c *CSIDriverNode) Start(l net.Listener) error {
c.lock.Lock()
defer c.lock.Unlock()
// Set listener.
c.listener = l
// Create a new grpc server.
c.server = grpc.NewServer(
grpc.UnaryInterceptor(c.callInterceptor),
)
if c.nodeServer.Node != nil {
csi.RegisterNodeServer(c.server, c.nodeServer.Node)
}
if c.nodeServer.Identity != nil {
csi.RegisterIdentityServer(c.server, c.nodeServer.Identity)
}
reflection.Register(c.server)
waitForServer := make(chan bool)
c.goServe(waitForServer)
<-waitForServer
c.running = true
return nil
}
func (c *CSIDriverNode) Stop() {
stop(&c.lock, &c.wg, c.server, c.running)
}
func (c *CSIDriverNode) Close() {
c.server.Stop()
}
func (c *CSIDriverNode) IsRunning() bool {
c.lock.Lock()
defer c.lock.Unlock()
return c.running
}
func (c *CSIDriverNode) SetDefaultCreds() {
setDefaultCreds(c.creds)
}
func (c *CSIDriverNode) callInterceptor(ctx context.Context, req interface{}, info *grpc.UnaryServerInfo, handler grpc.UnaryHandler) (interface{}, error) {
return callInterceptor(ctx, c.creds, req, info, handler)
}

View File

@ -0,0 +1,312 @@
/*
Copyright 2017 Luis Pabón luis@portworx.com
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/
//go:generate mockgen -package=driver -destination=driver.mock.go github.com/container-storage-interface/spec/lib/go/csi IdentityServer,ControllerServer,NodeServer
package driver
import (
"context"
"encoding/json"
"errors"
"net"
"sync"
"google.golang.org/grpc/codes"
"google.golang.org/grpc/status"
"k8s.io/klog"
"github.com/container-storage-interface/spec/lib/go/csi"
"google.golang.org/grpc"
"google.golang.org/grpc/reflection"
)
var (
// ErrNoCredentials is the error when a secret is enabled but not passed in the request.
ErrNoCredentials = errors.New("secret must be provided")
// ErrAuthFailed is the error when the secret is incorrect.
ErrAuthFailed = errors.New("authentication failed")
)
// CSIDriverServers is a unified driver component with both Controller and Node
// services.
type CSIDriverServers struct {
Controller csi.ControllerServer
Identity csi.IdentityServer
Node csi.NodeServer
}
// This is the key name in all the CSI secret objects.
const secretField = "secretKey"
// CSICreds is a driver specific secret type. Drivers can have a key-val pair of
// secrets. This mock driver has a single string secret with secretField as the
// key.
type CSICreds struct {
CreateVolumeSecret string
DeleteVolumeSecret string
ControllerPublishVolumeSecret string
ControllerUnpublishVolumeSecret string
NodeStageVolumeSecret string
NodePublishVolumeSecret string
CreateSnapshotSecret string
DeleteSnapshotSecret string
ControllerValidateVolumeCapabilitiesSecret string
}
type CSIDriver struct {
listener net.Listener
server *grpc.Server
servers *CSIDriverServers
wg sync.WaitGroup
running bool
lock sync.Mutex
creds *CSICreds
}
func NewCSIDriver(servers *CSIDriverServers) *CSIDriver {
return &CSIDriver{
servers: servers,
}
}
func (c *CSIDriver) goServe(started chan<- bool) {
goServe(c.server, &c.wg, c.listener, started)
}
func (c *CSIDriver) Address() string {
return c.listener.Addr().String()
}
func (c *CSIDriver) Start(l net.Listener) error {
c.lock.Lock()
defer c.lock.Unlock()
// Set listener
c.listener = l
// Create a new grpc server
c.server = grpc.NewServer(
grpc.UnaryInterceptor(c.callInterceptor),
)
// Register Mock servers
if c.servers.Controller != nil {
csi.RegisterControllerServer(c.server, c.servers.Controller)
}
if c.servers.Identity != nil {
csi.RegisterIdentityServer(c.server, c.servers.Identity)
}
if c.servers.Node != nil {
csi.RegisterNodeServer(c.server, c.servers.Node)
}
reflection.Register(c.server)
// Start listening for requests
waitForServer := make(chan bool)
c.goServe(waitForServer)
<-waitForServer
c.running = true
return nil
}
func (c *CSIDriver) Stop() {
stop(&c.lock, &c.wg, c.server, c.running)
}
func (c *CSIDriver) Close() {
c.server.Stop()
}
func (c *CSIDriver) IsRunning() bool {
c.lock.Lock()
defer c.lock.Unlock()
return c.running
}
// SetDefaultCreds sets the default secrets for CSI creds.
func (c *CSIDriver) SetDefaultCreds() {
setDefaultCreds(c.creds)
}
func (c *CSIDriver) callInterceptor(ctx context.Context, req interface{}, info *grpc.UnaryServerInfo, handler grpc.UnaryHandler) (interface{}, error) {
return callInterceptor(ctx, c.creds, req, info, handler)
}
// goServe starts a grpc server.
func goServe(server *grpc.Server, wg *sync.WaitGroup, listener net.Listener, started chan<- bool) {
wg.Add(1)
go func() {
defer wg.Done()
started <- true
err := server.Serve(listener)
if err != nil {
panic(err.Error())
}
}()
}
// stop stops a grpc server.
func stop(lock *sync.Mutex, wg *sync.WaitGroup, server *grpc.Server, running bool) {
lock.Lock()
defer lock.Unlock()
if !running {
return
}
server.Stop()
wg.Wait()
}
// setDefaultCreds sets the default credentials, given a CSICreds instance.
func setDefaultCreds(creds *CSICreds) {
creds = &CSICreds{
CreateVolumeSecret: "secretval1",
DeleteVolumeSecret: "secretval2",
ControllerPublishVolumeSecret: "secretval3",
ControllerUnpublishVolumeSecret: "secretval4",
NodeStageVolumeSecret: "secretval5",
NodePublishVolumeSecret: "secretval6",
CreateSnapshotSecret: "secretval7",
DeleteSnapshotSecret: "secretval8",
ControllerValidateVolumeCapabilitiesSecret: "secretval9",
}
}
func callInterceptor(ctx context.Context, creds *CSICreds, req interface{}, info *grpc.UnaryServerInfo, handler grpc.UnaryHandler) (interface{}, error) {
err := authInterceptor(creds, req)
if err != nil {
logGRPC(info.FullMethod, req, nil, err)
return nil, err
}
rsp, err := handler(ctx, req)
logGRPC(info.FullMethod, req, rsp, err)
return rsp, err
}
func authInterceptor(creds *CSICreds, req interface{}) error {
if creds != nil {
authenticated, authErr := isAuthenticated(req, creds)
if !authenticated {
if authErr == ErrNoCredentials {
return status.Error(codes.InvalidArgument, authErr.Error())
}
if authErr == ErrAuthFailed {
return status.Error(codes.Unauthenticated, authErr.Error())
}
}
}
return nil
}
func logGRPC(method string, request, reply interface{}, err error) {
// Log JSON with the request and response for easier parsing
logMessage := struct {
Method string
Request interface{}
Response interface{}
// Error as string, for backward compatibility.
// "" on no error.
Error string
// Full error dump, to be able to parse out full gRPC error code and message separately in a test.
FullError error
}{
Method: method,
Request: request,
Response: reply,
FullError: err,
}
if err != nil {
logMessage.Error = err.Error()
}
msg, _ := json.Marshal(logMessage)
klog.V(3).Infof("gRPCCall: %s\n", msg)
}
func isAuthenticated(req interface{}, creds *CSICreds) (bool, error) {
switch r := req.(type) {
case *csi.CreateVolumeRequest:
return authenticateCreateVolume(r, creds)
case *csi.DeleteVolumeRequest:
return authenticateDeleteVolume(r, creds)
case *csi.ControllerPublishVolumeRequest:
return authenticateControllerPublishVolume(r, creds)
case *csi.ControllerUnpublishVolumeRequest:
return authenticateControllerUnpublishVolume(r, creds)
case *csi.NodeStageVolumeRequest:
return authenticateNodeStageVolume(r, creds)
case *csi.NodePublishVolumeRequest:
return authenticateNodePublishVolume(r, creds)
case *csi.CreateSnapshotRequest:
return authenticateCreateSnapshot(r, creds)
case *csi.DeleteSnapshotRequest:
return authenticateDeleteSnapshot(r, creds)
case *csi.ValidateVolumeCapabilitiesRequest:
return authenticateControllerValidateVolumeCapabilities(r, creds)
default:
return true, nil
}
}
func authenticateCreateVolume(req *csi.CreateVolumeRequest, creds *CSICreds) (bool, error) {
return credsCheck(req.GetSecrets(), creds.CreateVolumeSecret)
}
func authenticateDeleteVolume(req *csi.DeleteVolumeRequest, creds *CSICreds) (bool, error) {
return credsCheck(req.GetSecrets(), creds.DeleteVolumeSecret)
}
func authenticateControllerPublishVolume(req *csi.ControllerPublishVolumeRequest, creds *CSICreds) (bool, error) {
return credsCheck(req.GetSecrets(), creds.ControllerPublishVolumeSecret)
}
func authenticateControllerUnpublishVolume(req *csi.ControllerUnpublishVolumeRequest, creds *CSICreds) (bool, error) {
return credsCheck(req.GetSecrets(), creds.ControllerUnpublishVolumeSecret)
}
func authenticateNodeStageVolume(req *csi.NodeStageVolumeRequest, creds *CSICreds) (bool, error) {
return credsCheck(req.GetSecrets(), creds.NodeStageVolumeSecret)
}
func authenticateNodePublishVolume(req *csi.NodePublishVolumeRequest, creds *CSICreds) (bool, error) {
return credsCheck(req.GetSecrets(), creds.NodePublishVolumeSecret)
}
func authenticateCreateSnapshot(req *csi.CreateSnapshotRequest, creds *CSICreds) (bool, error) {
return credsCheck(req.GetSecrets(), creds.CreateSnapshotSecret)
}
func authenticateDeleteSnapshot(req *csi.DeleteSnapshotRequest, creds *CSICreds) (bool, error) {
return credsCheck(req.GetSecrets(), creds.DeleteSnapshotSecret)
}
func authenticateControllerValidateVolumeCapabilities(req *csi.ValidateVolumeCapabilitiesRequest, creds *CSICreds) (bool, error) {
return credsCheck(req.GetSecrets(), creds.ControllerValidateVolumeCapabilitiesSecret)
}
func credsCheck(secrets map[string]string, secretVal string) (bool, error) {
if len(secrets) == 0 {
return false, ErrNoCredentials
}
if secrets[secretField] != secretVal {
return false, ErrAuthFailed
}
return true, nil
}

View File

@ -0,0 +1,392 @@
// Code generated by MockGen. DO NOT EDIT.
// Source: github.com/container-storage-interface/spec/lib/go/csi (interfaces: IdentityServer,ControllerServer,NodeServer)
// Package driver is a generated GoMock package.
package driver
import (
context "context"
csi "github.com/container-storage-interface/spec/lib/go/csi"
gomock "github.com/golang/mock/gomock"
reflect "reflect"
)
// MockIdentityServer is a mock of IdentityServer interface
type MockIdentityServer struct {
ctrl *gomock.Controller
recorder *MockIdentityServerMockRecorder
}
// MockIdentityServerMockRecorder is the mock recorder for MockIdentityServer
type MockIdentityServerMockRecorder struct {
mock *MockIdentityServer
}
// NewMockIdentityServer creates a new mock instance
func NewMockIdentityServer(ctrl *gomock.Controller) *MockIdentityServer {
mock := &MockIdentityServer{ctrl: ctrl}
mock.recorder = &MockIdentityServerMockRecorder{mock}
return mock
}
// EXPECT returns an object that allows the caller to indicate expected use
func (m *MockIdentityServer) EXPECT() *MockIdentityServerMockRecorder {
return m.recorder
}
// GetPluginCapabilities mocks base method
func (m *MockIdentityServer) GetPluginCapabilities(arg0 context.Context, arg1 *csi.GetPluginCapabilitiesRequest) (*csi.GetPluginCapabilitiesResponse, error) {
ret := m.ctrl.Call(m, "GetPluginCapabilities", arg0, arg1)
ret0, _ := ret[0].(*csi.GetPluginCapabilitiesResponse)
ret1, _ := ret[1].(error)
return ret0, ret1
}
// GetPluginCapabilities indicates an expected call of GetPluginCapabilities
func (mr *MockIdentityServerMockRecorder) GetPluginCapabilities(arg0, arg1 interface{}) *gomock.Call {
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetPluginCapabilities", reflect.TypeOf((*MockIdentityServer)(nil).GetPluginCapabilities), arg0, arg1)
}
// GetPluginInfo mocks base method
func (m *MockIdentityServer) GetPluginInfo(arg0 context.Context, arg1 *csi.GetPluginInfoRequest) (*csi.GetPluginInfoResponse, error) {
ret := m.ctrl.Call(m, "GetPluginInfo", arg0, arg1)
ret0, _ := ret[0].(*csi.GetPluginInfoResponse)
ret1, _ := ret[1].(error)
return ret0, ret1
}
// GetPluginInfo indicates an expected call of GetPluginInfo
func (mr *MockIdentityServerMockRecorder) GetPluginInfo(arg0, arg1 interface{}) *gomock.Call {
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetPluginInfo", reflect.TypeOf((*MockIdentityServer)(nil).GetPluginInfo), arg0, arg1)
}
// Probe mocks base method
func (m *MockIdentityServer) Probe(arg0 context.Context, arg1 *csi.ProbeRequest) (*csi.ProbeResponse, error) {
ret := m.ctrl.Call(m, "Probe", arg0, arg1)
ret0, _ := ret[0].(*csi.ProbeResponse)
ret1, _ := ret[1].(error)
return ret0, ret1
}
// Probe indicates an expected call of Probe
func (mr *MockIdentityServerMockRecorder) Probe(arg0, arg1 interface{}) *gomock.Call {
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Probe", reflect.TypeOf((*MockIdentityServer)(nil).Probe), arg0, arg1)
}
// MockControllerServer is a mock of ControllerServer interface
type MockControllerServer struct {
ctrl *gomock.Controller
recorder *MockControllerServerMockRecorder
}
// MockControllerServerMockRecorder is the mock recorder for MockControllerServer
type MockControllerServerMockRecorder struct {
mock *MockControllerServer
}
// NewMockControllerServer creates a new mock instance
func NewMockControllerServer(ctrl *gomock.Controller) *MockControllerServer {
mock := &MockControllerServer{ctrl: ctrl}
mock.recorder = &MockControllerServerMockRecorder{mock}
return mock
}
// EXPECT returns an object that allows the caller to indicate expected use
func (m *MockControllerServer) EXPECT() *MockControllerServerMockRecorder {
return m.recorder
}
// ControllerExpandVolume mocks base method
func (m *MockControllerServer) ControllerExpandVolume(arg0 context.Context, arg1 *csi.ControllerExpandVolumeRequest) (*csi.ControllerExpandVolumeResponse, error) {
ret := m.ctrl.Call(m, "ControllerExpandVolume", arg0, arg1)
ret0, _ := ret[0].(*csi.ControllerExpandVolumeResponse)
ret1, _ := ret[1].(error)
return ret0, ret1
}
// ControllerExpandVolume indicates an expected call of ControllerExpandVolume
func (mr *MockControllerServerMockRecorder) ControllerExpandVolume(arg0, arg1 interface{}) *gomock.Call {
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "ControllerExpandVolume", reflect.TypeOf((*MockControllerServer)(nil).ControllerExpandVolume), arg0, arg1)
}
// ControllerGetCapabilities mocks base method
func (m *MockControllerServer) ControllerGetCapabilities(arg0 context.Context, arg1 *csi.ControllerGetCapabilitiesRequest) (*csi.ControllerGetCapabilitiesResponse, error) {
ret := m.ctrl.Call(m, "ControllerGetCapabilities", arg0, arg1)
ret0, _ := ret[0].(*csi.ControllerGetCapabilitiesResponse)
ret1, _ := ret[1].(error)
return ret0, ret1
}
// ControllerGetCapabilities indicates an expected call of ControllerGetCapabilities
func (mr *MockControllerServerMockRecorder) ControllerGetCapabilities(arg0, arg1 interface{}) *gomock.Call {
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "ControllerGetCapabilities", reflect.TypeOf((*MockControllerServer)(nil).ControllerGetCapabilities), arg0, arg1)
}
// ControllerPublishVolume mocks base method
func (m *MockControllerServer) ControllerPublishVolume(arg0 context.Context, arg1 *csi.ControllerPublishVolumeRequest) (*csi.ControllerPublishVolumeResponse, error) {
ret := m.ctrl.Call(m, "ControllerPublishVolume", arg0, arg1)
ret0, _ := ret[0].(*csi.ControllerPublishVolumeResponse)
ret1, _ := ret[1].(error)
return ret0, ret1
}
// ControllerPublishVolume indicates an expected call of ControllerPublishVolume
func (mr *MockControllerServerMockRecorder) ControllerPublishVolume(arg0, arg1 interface{}) *gomock.Call {
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "ControllerPublishVolume", reflect.TypeOf((*MockControllerServer)(nil).ControllerPublishVolume), arg0, arg1)
}
// ControllerUnpublishVolume mocks base method
func (m *MockControllerServer) ControllerUnpublishVolume(arg0 context.Context, arg1 *csi.ControllerUnpublishVolumeRequest) (*csi.ControllerUnpublishVolumeResponse, error) {
ret := m.ctrl.Call(m, "ControllerUnpublishVolume", arg0, arg1)
ret0, _ := ret[0].(*csi.ControllerUnpublishVolumeResponse)
ret1, _ := ret[1].(error)
return ret0, ret1
}
// ControllerUnpublishVolume indicates an expected call of ControllerUnpublishVolume
func (mr *MockControllerServerMockRecorder) ControllerUnpublishVolume(arg0, arg1 interface{}) *gomock.Call {
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "ControllerUnpublishVolume", reflect.TypeOf((*MockControllerServer)(nil).ControllerUnpublishVolume), arg0, arg1)
}
// CreateSnapshot mocks base method
func (m *MockControllerServer) CreateSnapshot(arg0 context.Context, arg1 *csi.CreateSnapshotRequest) (*csi.CreateSnapshotResponse, error) {
ret := m.ctrl.Call(m, "CreateSnapshot", arg0, arg1)
ret0, _ := ret[0].(*csi.CreateSnapshotResponse)
ret1, _ := ret[1].(error)
return ret0, ret1
}
// CreateSnapshot indicates an expected call of CreateSnapshot
func (mr *MockControllerServerMockRecorder) CreateSnapshot(arg0, arg1 interface{}) *gomock.Call {
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "CreateSnapshot", reflect.TypeOf((*MockControllerServer)(nil).CreateSnapshot), arg0, arg1)
}
// CreateVolume mocks base method
func (m *MockControllerServer) CreateVolume(arg0 context.Context, arg1 *csi.CreateVolumeRequest) (*csi.CreateVolumeResponse, error) {
ret := m.ctrl.Call(m, "CreateVolume", arg0, arg1)
ret0, _ := ret[0].(*csi.CreateVolumeResponse)
ret1, _ := ret[1].(error)
return ret0, ret1
}
// CreateVolume indicates an expected call of CreateVolume
func (mr *MockControllerServerMockRecorder) CreateVolume(arg0, arg1 interface{}) *gomock.Call {
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "CreateVolume", reflect.TypeOf((*MockControllerServer)(nil).CreateVolume), arg0, arg1)
}
// DeleteSnapshot mocks base method
func (m *MockControllerServer) DeleteSnapshot(arg0 context.Context, arg1 *csi.DeleteSnapshotRequest) (*csi.DeleteSnapshotResponse, error) {
ret := m.ctrl.Call(m, "DeleteSnapshot", arg0, arg1)
ret0, _ := ret[0].(*csi.DeleteSnapshotResponse)
ret1, _ := ret[1].(error)
return ret0, ret1
}
// DeleteSnapshot indicates an expected call of DeleteSnapshot
func (mr *MockControllerServerMockRecorder) DeleteSnapshot(arg0, arg1 interface{}) *gomock.Call {
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "DeleteSnapshot", reflect.TypeOf((*MockControllerServer)(nil).DeleteSnapshot), arg0, arg1)
}
// DeleteVolume mocks base method
func (m *MockControllerServer) DeleteVolume(arg0 context.Context, arg1 *csi.DeleteVolumeRequest) (*csi.DeleteVolumeResponse, error) {
ret := m.ctrl.Call(m, "DeleteVolume", arg0, arg1)
ret0, _ := ret[0].(*csi.DeleteVolumeResponse)
ret1, _ := ret[1].(error)
return ret0, ret1
}
// DeleteVolume indicates an expected call of DeleteVolume
func (mr *MockControllerServerMockRecorder) DeleteVolume(arg0, arg1 interface{}) *gomock.Call {
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "DeleteVolume", reflect.TypeOf((*MockControllerServer)(nil).DeleteVolume), arg0, arg1)
}
// GetCapacity mocks base method
func (m *MockControllerServer) GetCapacity(arg0 context.Context, arg1 *csi.GetCapacityRequest) (*csi.GetCapacityResponse, error) {
ret := m.ctrl.Call(m, "GetCapacity", arg0, arg1)
ret0, _ := ret[0].(*csi.GetCapacityResponse)
ret1, _ := ret[1].(error)
return ret0, ret1
}
// GetCapacity indicates an expected call of GetCapacity
func (mr *MockControllerServerMockRecorder) GetCapacity(arg0, arg1 interface{}) *gomock.Call {
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetCapacity", reflect.TypeOf((*MockControllerServer)(nil).GetCapacity), arg0, arg1)
}
// ListSnapshots mocks base method
func (m *MockControllerServer) ListSnapshots(arg0 context.Context, arg1 *csi.ListSnapshotsRequest) (*csi.ListSnapshotsResponse, error) {
ret := m.ctrl.Call(m, "ListSnapshots", arg0, arg1)
ret0, _ := ret[0].(*csi.ListSnapshotsResponse)
ret1, _ := ret[1].(error)
return ret0, ret1
}
// ListSnapshots indicates an expected call of ListSnapshots
func (mr *MockControllerServerMockRecorder) ListSnapshots(arg0, arg1 interface{}) *gomock.Call {
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "ListSnapshots", reflect.TypeOf((*MockControllerServer)(nil).ListSnapshots), arg0, arg1)
}
// ListVolumes mocks base method
func (m *MockControllerServer) ListVolumes(arg0 context.Context, arg1 *csi.ListVolumesRequest) (*csi.ListVolumesResponse, error) {
ret := m.ctrl.Call(m, "ListVolumes", arg0, arg1)
ret0, _ := ret[0].(*csi.ListVolumesResponse)
ret1, _ := ret[1].(error)
return ret0, ret1
}
func (m *MockControllerServer) ControllerGetVolume(arg0 context.Context, arg1 *csi.ControllerGetVolumeRequest) (*csi.ControllerGetVolumeResponse, error) {
ret := m.ctrl.Call(m, "ControllerGetVolume", arg0, arg1)
ret0, _ := ret[0].(*csi.ControllerGetVolumeResponse)
ret1, _ := ret[1].(error)
return ret0, ret1
}
// ControllerGetVolume indicates an expected call of ControllerGetVolume
func (mr *MockControllerServerMockRecorder) ControllerGetVolume(arg0, arg1 interface{}) *gomock.Call {
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "ControllerGetVolume", reflect.TypeOf((*MockControllerServer)(nil).ControllerGetVolume), arg0, arg1)
}
// ListVolumes indicates an expected call of ListVolumes
func (mr *MockControllerServerMockRecorder) ListVolumes(arg0, arg1 interface{}) *gomock.Call {
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "ListVolumes", reflect.TypeOf((*MockControllerServer)(nil).ListVolumes), arg0, arg1)
}
// ValidateVolumeCapabilities mocks base method
func (m *MockControllerServer) ValidateVolumeCapabilities(arg0 context.Context, arg1 *csi.ValidateVolumeCapabilitiesRequest) (*csi.ValidateVolumeCapabilitiesResponse, error) {
ret := m.ctrl.Call(m, "ValidateVolumeCapabilities", arg0, arg1)
ret0, _ := ret[0].(*csi.ValidateVolumeCapabilitiesResponse)
ret1, _ := ret[1].(error)
return ret0, ret1
}
// ValidateVolumeCapabilities indicates an expected call of ValidateVolumeCapabilities
func (mr *MockControllerServerMockRecorder) ValidateVolumeCapabilities(arg0, arg1 interface{}) *gomock.Call {
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "ValidateVolumeCapabilities", reflect.TypeOf((*MockControllerServer)(nil).ValidateVolumeCapabilities), arg0, arg1)
}
// MockNodeServer is a mock of NodeServer interface
type MockNodeServer struct {
ctrl *gomock.Controller
recorder *MockNodeServerMockRecorder
}
// MockNodeServerMockRecorder is the mock recorder for MockNodeServer
type MockNodeServerMockRecorder struct {
mock *MockNodeServer
}
// NewMockNodeServer creates a new mock instance
func NewMockNodeServer(ctrl *gomock.Controller) *MockNodeServer {
mock := &MockNodeServer{ctrl: ctrl}
mock.recorder = &MockNodeServerMockRecorder{mock}
return mock
}
// EXPECT returns an object that allows the caller to indicate expected use
func (m *MockNodeServer) EXPECT() *MockNodeServerMockRecorder {
return m.recorder
}
// NodeExpandVolume mocks base method
func (m *MockNodeServer) NodeExpandVolume(arg0 context.Context, arg1 *csi.NodeExpandVolumeRequest) (*csi.NodeExpandVolumeResponse, error) {
ret := m.ctrl.Call(m, "NodeExpandVolume", arg0, arg1)
ret0, _ := ret[0].(*csi.NodeExpandVolumeResponse)
ret1, _ := ret[1].(error)
return ret0, ret1
}
// NodeExpandVolume indicates an expected call of NodeExpandVolume
func (mr *MockNodeServerMockRecorder) NodeExpandVolume(arg0, arg1 interface{}) *gomock.Call {
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "NodeExpandVolume", reflect.TypeOf((*MockNodeServer)(nil).NodeExpandVolume), arg0, arg1)
}
// NodeGetCapabilities mocks base method
func (m *MockNodeServer) NodeGetCapabilities(arg0 context.Context, arg1 *csi.NodeGetCapabilitiesRequest) (*csi.NodeGetCapabilitiesResponse, error) {
ret := m.ctrl.Call(m, "NodeGetCapabilities", arg0, arg1)
ret0, _ := ret[0].(*csi.NodeGetCapabilitiesResponse)
ret1, _ := ret[1].(error)
return ret0, ret1
}
// NodeGetCapabilities indicates an expected call of NodeGetCapabilities
func (mr *MockNodeServerMockRecorder) NodeGetCapabilities(arg0, arg1 interface{}) *gomock.Call {
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "NodeGetCapabilities", reflect.TypeOf((*MockNodeServer)(nil).NodeGetCapabilities), arg0, arg1)
}
// NodeGetInfo mocks base method
func (m *MockNodeServer) NodeGetInfo(arg0 context.Context, arg1 *csi.NodeGetInfoRequest) (*csi.NodeGetInfoResponse, error) {
ret := m.ctrl.Call(m, "NodeGetInfo", arg0, arg1)
ret0, _ := ret[0].(*csi.NodeGetInfoResponse)
ret1, _ := ret[1].(error)
return ret0, ret1
}
// NodeGetInfo indicates an expected call of NodeGetInfo
func (mr *MockNodeServerMockRecorder) NodeGetInfo(arg0, arg1 interface{}) *gomock.Call {
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "NodeGetInfo", reflect.TypeOf((*MockNodeServer)(nil).NodeGetInfo), arg0, arg1)
}
// NodeGetVolumeStats mocks base method
func (m *MockNodeServer) NodeGetVolumeStats(arg0 context.Context, arg1 *csi.NodeGetVolumeStatsRequest) (*csi.NodeGetVolumeStatsResponse, error) {
ret := m.ctrl.Call(m, "NodeGetVolumeStats", arg0, arg1)
ret0, _ := ret[0].(*csi.NodeGetVolumeStatsResponse)
ret1, _ := ret[1].(error)
return ret0, ret1
}
// NodeGetVolumeStats indicates an expected call of NodeGetVolumeStats
func (mr *MockNodeServerMockRecorder) NodeGetVolumeStats(arg0, arg1 interface{}) *gomock.Call {
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "NodeGetVolumeStats", reflect.TypeOf((*MockNodeServer)(nil).NodeGetVolumeStats), arg0, arg1)
}
// NodePublishVolume mocks base method
func (m *MockNodeServer) NodePublishVolume(arg0 context.Context, arg1 *csi.NodePublishVolumeRequest) (*csi.NodePublishVolumeResponse, error) {
ret := m.ctrl.Call(m, "NodePublishVolume", arg0, arg1)
ret0, _ := ret[0].(*csi.NodePublishVolumeResponse)
ret1, _ := ret[1].(error)
return ret0, ret1
}
// NodePublishVolume indicates an expected call of NodePublishVolume
func (mr *MockNodeServerMockRecorder) NodePublishVolume(arg0, arg1 interface{}) *gomock.Call {
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "NodePublishVolume", reflect.TypeOf((*MockNodeServer)(nil).NodePublishVolume), arg0, arg1)
}
// NodeStageVolume mocks base method
func (m *MockNodeServer) NodeStageVolume(arg0 context.Context, arg1 *csi.NodeStageVolumeRequest) (*csi.NodeStageVolumeResponse, error) {
ret := m.ctrl.Call(m, "NodeStageVolume", arg0, arg1)
ret0, _ := ret[0].(*csi.NodeStageVolumeResponse)
ret1, _ := ret[1].(error)
return ret0, ret1
}
// NodeStageVolume indicates an expected call of NodeStageVolume
func (mr *MockNodeServerMockRecorder) NodeStageVolume(arg0, arg1 interface{}) *gomock.Call {
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "NodeStageVolume", reflect.TypeOf((*MockNodeServer)(nil).NodeStageVolume), arg0, arg1)
}
// NodeUnpublishVolume mocks base method
func (m *MockNodeServer) NodeUnpublishVolume(arg0 context.Context, arg1 *csi.NodeUnpublishVolumeRequest) (*csi.NodeUnpublishVolumeResponse, error) {
ret := m.ctrl.Call(m, "NodeUnpublishVolume", arg0, arg1)
ret0, _ := ret[0].(*csi.NodeUnpublishVolumeResponse)
ret1, _ := ret[1].(error)
return ret0, ret1
}
// NodeUnpublishVolume indicates an expected call of NodeUnpublishVolume
func (mr *MockNodeServerMockRecorder) NodeUnpublishVolume(arg0, arg1 interface{}) *gomock.Call {
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "NodeUnpublishVolume", reflect.TypeOf((*MockNodeServer)(nil).NodeUnpublishVolume), arg0, arg1)
}
// NodeUnstageVolume mocks base method
func (m *MockNodeServer) NodeUnstageVolume(arg0 context.Context, arg1 *csi.NodeUnstageVolumeRequest) (*csi.NodeUnstageVolumeResponse, error) {
ret := m.ctrl.Call(m, "NodeUnstageVolume", arg0, arg1)
ret0, _ := ret[0].(*csi.NodeUnstageVolumeResponse)
ret1, _ := ret[1].(error)
return ret0, ret1
}
// NodeUnstageVolume indicates an expected call of NodeUnstageVolume
func (mr *MockNodeServerMockRecorder) NodeUnstageVolume(arg0, arg1 interface{}) *gomock.Call {
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "NodeUnstageVolume", reflect.TypeOf((*MockNodeServer)(nil).NodeUnstageVolume), arg0, arg1)
}

View File

@ -0,0 +1,89 @@
/*
Copyright 2017 Luis Pabón luis@portworx.com
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/
package driver
import (
"net"
"github.com/kubernetes-csi/csi-test/v4/utils"
"google.golang.org/grpc"
)
type MockCSIDriverServers struct {
Controller *MockControllerServer
Identity *MockIdentityServer
Node *MockNodeServer
}
type MockCSIDriver struct {
CSIDriver
conn *grpc.ClientConn
}
func NewMockCSIDriver(servers *MockCSIDriverServers) *MockCSIDriver {
return &MockCSIDriver{
CSIDriver: CSIDriver{
servers: &CSIDriverServers{
Controller: servers.Controller,
Node: servers.Node,
Identity: servers.Identity,
},
},
}
}
// StartOnAddress starts a new gRPC server listening on given address.
func (m *MockCSIDriver) StartOnAddress(network, address string) error {
l, err := net.Listen(network, address)
if err != nil {
return err
}
if err := m.CSIDriver.Start(l); err != nil {
l.Close()
return err
}
return nil
}
// Start starts a new gRPC server listening on a random TCP loopback port.
func (m *MockCSIDriver) Start() error {
// Listen on a port assigned by the net package
return m.StartOnAddress("tcp", "127.0.0.1:0")
}
func (m *MockCSIDriver) Nexus() (*grpc.ClientConn, error) {
// Start server
err := m.Start()
if err != nil {
return nil, err
}
// Create a client connection
m.conn, err = utils.Connect(m.Address(), grpc.WithInsecure())
if err != nil {
return nil, err
}
return m.conn, nil
}
func (m *MockCSIDriver) Close() {
m.conn.Close()
m.server.Stop()
}

View File

@ -0,0 +1,89 @@
package cache
import (
"strings"
"sync"
"github.com/container-storage-interface/spec/lib/go/csi"
)
type SnapshotCache interface {
Add(snapshot Snapshot)
Delete(i int)
List(ready bool) []csi.Snapshot
FindSnapshot(k, v string) (int, Snapshot)
}
type Snapshot struct {
Name string
Parameters map[string]string
SnapshotCSI csi.Snapshot
}
type snapshotCache struct {
snapshotsRWL sync.RWMutex
snapshots []Snapshot
}
func NewSnapshotCache() SnapshotCache {
return &snapshotCache{
snapshots: make([]Snapshot, 0),
}
}
func (snap *snapshotCache) Add(snapshot Snapshot) {
snap.snapshotsRWL.Lock()
defer snap.snapshotsRWL.Unlock()
snap.snapshots = append(snap.snapshots, snapshot)
}
func (snap *snapshotCache) Delete(i int) {
snap.snapshotsRWL.Lock()
defer snap.snapshotsRWL.Unlock()
copy(snap.snapshots[i:], snap.snapshots[i+1:])
snap.snapshots = snap.snapshots[:len(snap.snapshots)-1]
}
func (snap *snapshotCache) List(ready bool) []csi.Snapshot {
snap.snapshotsRWL.RLock()
defer snap.snapshotsRWL.RUnlock()
snapshots := make([]csi.Snapshot, 0)
for _, v := range snap.snapshots {
if v.SnapshotCSI.GetReadyToUse() {
snapshots = append(snapshots, v.SnapshotCSI)
}
}
return snapshots
}
func (snap *snapshotCache) FindSnapshot(k, v string) (int, Snapshot) {
snap.snapshotsRWL.RLock()
defer snap.snapshotsRWL.RUnlock()
snapshotIdx := -1
for i, vi := range snap.snapshots {
switch k {
case "id":
if strings.EqualFold(v, vi.SnapshotCSI.GetSnapshotId()) {
return i, vi
}
case "sourceVolumeId":
if strings.EqualFold(v, vi.SnapshotCSI.SourceVolumeId) {
return i, vi
}
case "name":
if vi.Name == v {
return i, vi
}
}
}
return snapshotIdx, Snapshot{}
}

View File

@ -0,0 +1,834 @@
package service
import (
"fmt"
"math"
"path"
"reflect"
"strconv"
"github.com/container-storage-interface/spec/lib/go/csi"
log "github.com/sirupsen/logrus"
"golang.org/x/net/context"
"google.golang.org/grpc/codes"
"google.golang.org/grpc/status"
)
const (
MaxStorageCapacity = tib
ReadOnlyKey = "readonly"
)
func (s *service) CreateVolume(
ctx context.Context,
req *csi.CreateVolumeRequest) (
*csi.CreateVolumeResponse, error) {
if len(req.Name) == 0 {
return nil, status.Error(codes.InvalidArgument, "Volume Name cannot be empty")
}
if req.VolumeCapabilities == nil {
return nil, status.Error(codes.InvalidArgument, "Volume Capabilities cannot be empty")
}
if hookVal, hookMsg := s.execHook("CreateVolumeStart"); hookVal != codes.OK {
return nil, status.Errorf(hookVal, hookMsg)
}
// Check to see if the volume already exists.
if i, v := s.findVolByName(ctx, req.Name); i >= 0 {
// Requested volume name already exists, need to check if the existing volume's
// capacity is more or equal to new request's capacity.
if v.GetCapacityBytes() < req.GetCapacityRange().GetRequiredBytes() {
return nil, status.Error(codes.AlreadyExists,
fmt.Sprintf("Volume with name %s already exists", req.GetName()))
}
return &csi.CreateVolumeResponse{Volume: &v}, nil
}
// If no capacity is specified then use 100GiB
capacity := gib100
if cr := req.CapacityRange; cr != nil {
if rb := cr.RequiredBytes; rb > 0 {
capacity = rb
}
if lb := cr.LimitBytes; lb > 0 {
capacity = lb
}
}
// Check for maximum available capacity
if capacity >= MaxStorageCapacity {
return nil, status.Errorf(codes.OutOfRange, "Requested capacity %d exceeds maximum allowed %d", capacity, MaxStorageCapacity)
}
var v csi.Volume
// Create volume from content source if provided.
if req.GetVolumeContentSource() != nil {
switch req.GetVolumeContentSource().GetType().(type) {
case *csi.VolumeContentSource_Snapshot:
sid := req.GetVolumeContentSource().GetSnapshot().GetSnapshotId()
// Check if the source snapshot exists.
if snapID, _ := s.snapshots.FindSnapshot("id", sid); snapID >= 0 {
v = s.newVolumeFromSnapshot(req.Name, capacity, snapID)
} else {
return nil, status.Errorf(codes.NotFound, "Requested source snapshot %s not found", sid)
}
case *csi.VolumeContentSource_Volume:
vid := req.GetVolumeContentSource().GetVolume().GetVolumeId()
// Check if the source volume exists.
if volID, _ := s.findVolNoLock("id", vid); volID >= 0 {
v = s.newVolumeFromVolume(req.Name, capacity, volID)
} else {
return nil, status.Errorf(codes.NotFound, "Requested source volume %s not found", vid)
}
}
} else {
v = s.newVolume(req.Name, capacity)
}
// Add the created volume to the service's in-mem volume slice.
s.volsRWL.Lock()
defer s.volsRWL.Unlock()
s.vols = append(s.vols, v)
MockVolumes[v.GetVolumeId()] = Volume{
VolumeCSI: v,
NodeID: "",
ISStaged: false,
ISPublished: false,
StageTargetPath: "",
TargetPath: "",
}
if hookVal, hookMsg := s.execHook("CreateVolumeEnd"); hookVal != codes.OK {
return nil, status.Errorf(hookVal, hookMsg)
}
return &csi.CreateVolumeResponse{Volume: &v}, nil
}
func (s *service) DeleteVolume(
ctx context.Context,
req *csi.DeleteVolumeRequest) (
*csi.DeleteVolumeResponse, error) {
s.volsRWL.Lock()
defer s.volsRWL.Unlock()
// If the volume is not specified, return error
if len(req.VolumeId) == 0 {
return nil, status.Error(codes.InvalidArgument, "Volume ID cannot be empty")
}
if hookVal, hookMsg := s.execHook("DeleteVolumeStart"); hookVal != codes.OK {
return nil, status.Errorf(hookVal, hookMsg)
}
// If the volume does not exist then return an idempotent response.
i, _ := s.findVolNoLock("id", req.VolumeId)
if i < 0 {
return &csi.DeleteVolumeResponse{}, nil
}
// This delete logic preserves order and prevents potential memory
// leaks. The slice's elements may not be pointers, but the structs
// themselves have fields that are.
copy(s.vols[i:], s.vols[i+1:])
s.vols[len(s.vols)-1] = csi.Volume{}
s.vols = s.vols[:len(s.vols)-1]
log.WithField("volumeID", req.VolumeId).Debug("mock delete volume")
if hookVal, hookMsg := s.execHook("DeleteVolumeEnd"); hookVal != codes.OK {
return nil, status.Errorf(hookVal, hookMsg)
}
return &csi.DeleteVolumeResponse{}, nil
}
func (s *service) ControllerPublishVolume(
ctx context.Context,
req *csi.ControllerPublishVolumeRequest) (
*csi.ControllerPublishVolumeResponse, error) {
if s.config.DisableAttach {
return nil, status.Error(codes.Unimplemented, "ControllerPublish is not supported")
}
if len(req.VolumeId) == 0 {
return nil, status.Error(codes.InvalidArgument, "Volume ID cannot be empty")
}
if len(req.NodeId) == 0 {
return nil, status.Error(codes.InvalidArgument, "Node ID cannot be empty")
}
if req.VolumeCapability == nil {
return nil, status.Error(codes.InvalidArgument, "Volume Capabilities cannot be empty")
}
if req.NodeId != s.nodeID {
return nil, status.Errorf(codes.NotFound, "Not matching Node ID %s to Mock Node ID %s", req.NodeId, s.nodeID)
}
if hookVal, hookMsg := s.execHook("ControllerPublishVolumeStart"); hookVal != codes.OK {
return nil, status.Errorf(hookVal, hookMsg)
}
s.volsRWL.Lock()
defer s.volsRWL.Unlock()
i, v := s.findVolNoLock("id", req.VolumeId)
if i < 0 {
return nil, status.Error(codes.NotFound, req.VolumeId)
}
// devPathKey is the key in the volume's attributes that is set to a
// mock device path if the volume has been published by the controller
// to the specified node.
devPathKey := path.Join(req.NodeId, "dev")
// Check to see if the volume is already published.
if device := v.VolumeContext[devPathKey]; device != "" {
var volRo bool
var roVal string
if ro, ok := v.VolumeContext[ReadOnlyKey]; ok {
roVal = ro
}
if roVal == "true" {
volRo = true
} else {
volRo = false
}
// Check if readonly flag is compatible with the publish request.
if req.GetReadonly() != volRo {
return nil, status.Error(codes.AlreadyExists, "Volume published but has incompatible readonly flag")
}
return &csi.ControllerPublishVolumeResponse{
PublishContext: map[string]string{
"device": device,
"readonly": roVal,
},
}, nil
}
// Check attach limit before publishing only if attach limit is set.
if s.config.AttachLimit > 0 && s.getAttachCount(devPathKey) >= s.config.AttachLimit {
return nil, status.Errorf(codes.ResourceExhausted, "Cannot attach any more volumes to this node")
}
var roVal string
if req.GetReadonly() {
roVal = "true"
} else {
roVal = "false"
}
// Publish the volume.
device := "/dev/mock"
v.VolumeContext[devPathKey] = device
v.VolumeContext[ReadOnlyKey] = roVal
s.vols[i] = v
if volInfo, ok := MockVolumes[req.VolumeId]; ok {
volInfo.ISControllerPublished = true
MockVolumes[req.VolumeId] = volInfo
}
if hookVal, hookMsg := s.execHook("ControllerPublishVolumeEnd"); hookVal != codes.OK {
return nil, status.Errorf(hookVal, hookMsg)
}
return &csi.ControllerPublishVolumeResponse{
PublishContext: map[string]string{
"device": device,
"readonly": roVal,
},
}, nil
}
func (s *service) ControllerUnpublishVolume(
ctx context.Context,
req *csi.ControllerUnpublishVolumeRequest) (
*csi.ControllerUnpublishVolumeResponse, error) {
if s.config.DisableAttach {
return nil, status.Error(codes.Unimplemented, "ControllerPublish is not supported")
}
if len(req.VolumeId) == 0 {
return nil, status.Error(codes.InvalidArgument, "Volume ID cannot be empty")
}
nodeID := req.NodeId
if len(nodeID) == 0 {
// If node id is empty, no failure as per Spec
nodeID = s.nodeID
}
if req.NodeId != s.nodeID {
return nil, status.Errorf(codes.NotFound, "Node ID %s does not match to expected Node ID %s", req.NodeId, s.nodeID)
}
if hookVal, hookMsg := s.execHook("ControllerUnpublishVolumeStart"); hookVal != codes.OK {
return nil, status.Errorf(hookVal, hookMsg)
}
s.volsRWL.Lock()
defer s.volsRWL.Unlock()
i, v := s.findVolNoLock("id", req.VolumeId)
if i < 0 {
// Not an error: a non-existent volume is not published.
// See also https://github.com/kubernetes-csi/external-attacher/pull/165
return &csi.ControllerUnpublishVolumeResponse{}, nil
}
// devPathKey is the key in the volume's attributes that is set to a
// mock device path if the volume has been published by the controller
// to the specified node.
devPathKey := path.Join(nodeID, "dev")
// Check to see if the volume is already unpublished.
if v.VolumeContext[devPathKey] == "" {
return &csi.ControllerUnpublishVolumeResponse{}, nil
}
// Unpublish the volume.
delete(v.VolumeContext, devPathKey)
delete(v.VolumeContext, ReadOnlyKey)
s.vols[i] = v
if hookVal, hookMsg := s.execHook("ControllerUnpublishVolumeEnd"); hookVal != codes.OK {
return nil, status.Errorf(hookVal, hookMsg)
}
return &csi.ControllerUnpublishVolumeResponse{}, nil
}
func (s *service) ValidateVolumeCapabilities(
ctx context.Context,
req *csi.ValidateVolumeCapabilitiesRequest) (
*csi.ValidateVolumeCapabilitiesResponse, error) {
if len(req.GetVolumeId()) == 0 {
return nil, status.Error(codes.InvalidArgument, "Volume ID cannot be empty")
}
if len(req.VolumeCapabilities) == 0 {
return nil, status.Error(codes.InvalidArgument, req.VolumeId)
}
i, _ := s.findVolNoLock("id", req.VolumeId)
if i < 0 {
return nil, status.Error(codes.NotFound, req.VolumeId)
}
if hookVal, hookMsg := s.execHook("ValidateVolumeCapabilities"); hookVal != codes.OK {
return nil, status.Errorf(hookVal, hookMsg)
}
return &csi.ValidateVolumeCapabilitiesResponse{
Confirmed: &csi.ValidateVolumeCapabilitiesResponse_Confirmed{
VolumeContext: req.GetVolumeContext(),
VolumeCapabilities: req.GetVolumeCapabilities(),
Parameters: req.GetParameters(),
},
}, nil
}
func (s *service) ControllerGetVolume(
ctx context.Context,
req *csi.ControllerGetVolumeRequest) (
*csi.ControllerGetVolumeResponse, error) {
if hookVal, hookMsg := s.execHook("GetVolumeStart"); hookVal != codes.OK {
return nil, status.Errorf(hookVal, hookMsg)
}
resp := &csi.ControllerGetVolumeResponse{
Status: &csi.ControllerGetVolumeResponse_VolumeStatus{
VolumeCondition: &csi.VolumeCondition{},
},
}
i, v := s.findVolByID(ctx, req.VolumeId)
if i < 0 {
resp.Status.VolumeCondition.Abnormal = true
resp.Status.VolumeCondition.Message = "volume not found"
return resp, status.Error(codes.NotFound, req.VolumeId)
}
resp.Volume = &v
if !s.config.DisableAttach {
resp.Status.PublishedNodeIds = []string{
s.nodeID,
}
}
if hookVal, hookMsg := s.execHook("GetVolumeEnd"); hookVal != codes.OK {
return nil, status.Errorf(hookVal, hookMsg)
}
return resp, nil
}
func (s *service) ListVolumes(
ctx context.Context,
req *csi.ListVolumesRequest) (
*csi.ListVolumesResponse, error) {
if hookVal, hookMsg := s.execHook("ListVolumesStart"); hookVal != codes.OK {
return nil, status.Errorf(hookVal, hookMsg)
}
// Copy the mock volumes into a new slice in order to avoid
// locking the service's volume slice for the duration of the
// ListVolumes RPC.
var vols []csi.Volume
func() {
s.volsRWL.RLock()
defer s.volsRWL.RUnlock()
vols = make([]csi.Volume, len(s.vols))
copy(vols, s.vols)
}()
var (
ulenVols = int32(len(vols))
maxEntries = req.MaxEntries
startingToken int32
)
if v := req.StartingToken; v != "" {
i, err := strconv.ParseUint(v, 10, 32)
if err != nil {
return nil, status.Errorf(
codes.Aborted,
"startingToken=%d !< int32=%d",
startingToken, math.MaxUint32)
}
startingToken = int32(i)
}
if startingToken > ulenVols {
return nil, status.Errorf(
codes.Aborted,
"startingToken=%d > len(vols)=%d",
startingToken, ulenVols)
}
// Discern the number of remaining entries.
rem := ulenVols - startingToken
// If maxEntries is 0 or greater than the number of remaining entries then
// set maxEntries to the number of remaining entries.
if maxEntries == 0 || maxEntries > rem {
maxEntries = rem
}
var (
i int
j = startingToken
entries = make(
[]*csi.ListVolumesResponse_Entry,
maxEntries)
)
for i = 0; i < len(entries); i++ {
volumeStatus := &csi.ListVolumesResponse_VolumeStatus{
VolumeCondition: &csi.VolumeCondition{},
}
if !s.config.DisableAttach {
volumeStatus.PublishedNodeIds = []string{
s.nodeID,
}
}
entries[i] = &csi.ListVolumesResponse_Entry{
Volume: &vols[j],
Status: volumeStatus,
}
j++
}
var nextToken string
if n := startingToken + int32(i); n < ulenVols {
nextToken = fmt.Sprintf("%d", n)
}
if hookVal, hookMsg := s.execHook("ListVolumesEnd"); hookVal != codes.OK {
return nil, status.Errorf(hookVal, hookMsg)
}
return &csi.ListVolumesResponse{
Entries: entries,
NextToken: nextToken,
}, nil
}
func (s *service) GetCapacity(
ctx context.Context,
req *csi.GetCapacityRequest) (
*csi.GetCapacityResponse, error) {
if hookVal, hookMsg := s.execHook("GetCapacity"); hookVal != codes.OK {
return nil, status.Errorf(hookVal, hookMsg)
}
return &csi.GetCapacityResponse{
AvailableCapacity: MaxStorageCapacity,
}, nil
}
func (s *service) ControllerGetCapabilities(
ctx context.Context,
req *csi.ControllerGetCapabilitiesRequest) (
*csi.ControllerGetCapabilitiesResponse, error) {
if hookVal, hookMsg := s.execHook("ControllerGetCapabilitiesStart"); hookVal != codes.OK {
return nil, status.Errorf(hookVal, hookMsg)
}
caps := []*csi.ControllerServiceCapability{
{
Type: &csi.ControllerServiceCapability_Rpc{
Rpc: &csi.ControllerServiceCapability_RPC{
Type: csi.ControllerServiceCapability_RPC_CREATE_DELETE_VOLUME,
},
},
},
{
Type: &csi.ControllerServiceCapability_Rpc{
Rpc: &csi.ControllerServiceCapability_RPC{
Type: csi.ControllerServiceCapability_RPC_LIST_VOLUMES,
},
},
},
{
Type: &csi.ControllerServiceCapability_Rpc{
Rpc: &csi.ControllerServiceCapability_RPC{
Type: csi.ControllerServiceCapability_RPC_LIST_VOLUMES_PUBLISHED_NODES,
},
},
},
{
Type: &csi.ControllerServiceCapability_Rpc{
Rpc: &csi.ControllerServiceCapability_RPC{
Type: csi.ControllerServiceCapability_RPC_GET_CAPACITY,
},
},
},
{
Type: &csi.ControllerServiceCapability_Rpc{
Rpc: &csi.ControllerServiceCapability_RPC{
Type: csi.ControllerServiceCapability_RPC_LIST_SNAPSHOTS,
},
},
},
{
Type: &csi.ControllerServiceCapability_Rpc{
Rpc: &csi.ControllerServiceCapability_RPC{
Type: csi.ControllerServiceCapability_RPC_CREATE_DELETE_SNAPSHOT,
},
},
},
{
Type: &csi.ControllerServiceCapability_Rpc{
Rpc: &csi.ControllerServiceCapability_RPC{
Type: csi.ControllerServiceCapability_RPC_PUBLISH_READONLY,
},
},
},
{
Type: &csi.ControllerServiceCapability_Rpc{
Rpc: &csi.ControllerServiceCapability_RPC{
Type: csi.ControllerServiceCapability_RPC_CLONE_VOLUME,
},
},
},
{
Type: &csi.ControllerServiceCapability_Rpc{
Rpc: &csi.ControllerServiceCapability_RPC{
Type: csi.ControllerServiceCapability_RPC_GET_VOLUME,
},
},
},
{
Type: &csi.ControllerServiceCapability_Rpc{
Rpc: &csi.ControllerServiceCapability_RPC{
Type: csi.ControllerServiceCapability_RPC_VOLUME_CONDITION,
},
},
},
}
if !s.config.DisableAttach {
caps = append(caps, &csi.ControllerServiceCapability{
Type: &csi.ControllerServiceCapability_Rpc{
Rpc: &csi.ControllerServiceCapability_RPC{
Type: csi.ControllerServiceCapability_RPC_PUBLISH_UNPUBLISH_VOLUME,
},
},
})
}
if !s.config.DisableControllerExpansion {
caps = append(caps, &csi.ControllerServiceCapability{
Type: &csi.ControllerServiceCapability_Rpc{
Rpc: &csi.ControllerServiceCapability_RPC{
Type: csi.ControllerServiceCapability_RPC_EXPAND_VOLUME,
},
},
})
}
if hookVal, hookMsg := s.execHook("ControllerGetCapabilitiesEnd"); hookVal != codes.OK {
return nil, status.Errorf(hookVal, hookMsg)
}
return &csi.ControllerGetCapabilitiesResponse{
Capabilities: caps,
}, nil
}
func (s *service) CreateSnapshot(ctx context.Context,
req *csi.CreateSnapshotRequest) (*csi.CreateSnapshotResponse, error) {
// Check arguments
if len(req.GetName()) == 0 {
return nil, status.Error(codes.InvalidArgument, "Snapshot Name cannot be empty")
}
if len(req.GetSourceVolumeId()) == 0 {
return nil, status.Error(codes.InvalidArgument, "Snapshot SourceVolumeId cannot be empty")
}
if hookVal, hookMsg := s.execHook("CreateSnapshotStart"); hookVal != codes.OK {
return nil, status.Errorf(hookVal, hookMsg)
}
// Check to see if the snapshot already exists.
if i, v := s.snapshots.FindSnapshot("name", req.GetName()); i >= 0 {
// Requested snapshot name already exists
if v.SnapshotCSI.GetSourceVolumeId() != req.GetSourceVolumeId() || !reflect.DeepEqual(v.Parameters, req.GetParameters()) {
return nil, status.Error(codes.AlreadyExists,
fmt.Sprintf("Snapshot with name %s already exists", req.GetName()))
}
return &csi.CreateSnapshotResponse{Snapshot: &v.SnapshotCSI}, nil
}
// Create the snapshot and add it to the service's in-mem snapshot slice.
snapshot := s.newSnapshot(req.GetName(), req.GetSourceVolumeId(), req.GetParameters())
s.snapshots.Add(snapshot)
if hookVal, hookMsg := s.execHook("CreateSnapshotEnd"); hookVal != codes.OK {
return nil, status.Errorf(hookVal, hookMsg)
}
return &csi.CreateSnapshotResponse{Snapshot: &snapshot.SnapshotCSI}, nil
}
func (s *service) DeleteSnapshot(ctx context.Context,
req *csi.DeleteSnapshotRequest) (*csi.DeleteSnapshotResponse, error) {
// If the snapshot is not specified, return error
if len(req.SnapshotId) == 0 {
return nil, status.Error(codes.InvalidArgument, "Snapshot ID cannot be empty")
}
if hookVal, hookMsg := s.execHook("DeleteSnapshotStart"); hookVal != codes.OK {
return nil, status.Errorf(hookVal, hookMsg)
}
// If the snapshot does not exist then return an idempotent response.
i, _ := s.snapshots.FindSnapshot("id", req.SnapshotId)
if i < 0 {
return &csi.DeleteSnapshotResponse{}, nil
}
// This delete logic preserves order and prevents potential memory
// leaks. The slice's elements may not be pointers, but the structs
// themselves have fields that are.
s.snapshots.Delete(i)
log.WithField("SnapshotId", req.SnapshotId).Debug("mock delete snapshot")
if hookVal, hookMsg := s.execHook("DeleteSnapshotEnd"); hookVal != codes.OK {
return nil, status.Errorf(hookVal, hookMsg)
}
return &csi.DeleteSnapshotResponse{}, nil
}
func (s *service) ListSnapshots(ctx context.Context,
req *csi.ListSnapshotsRequest) (*csi.ListSnapshotsResponse, error) {
if hookVal, hookMsg := s.execHook("ListSnapshots"); hookVal != codes.OK {
return nil, status.Errorf(hookVal, hookMsg)
}
// case 1: SnapshotId is not empty, return snapshots that match the snapshot id.
if len(req.GetSnapshotId()) != 0 {
return getSnapshotById(s, req)
}
// case 2: SourceVolumeId is not empty, return snapshots that match the source volume id.
if len(req.GetSourceVolumeId()) != 0 {
return getSnapshotByVolumeId(s, req)
}
// case 3: no parameter is set, so we return all the snapshots.
return getAllSnapshots(s, req)
}
func (s *service) ControllerExpandVolume(
ctx context.Context,
req *csi.ControllerExpandVolumeRequest) (*csi.ControllerExpandVolumeResponse, error) {
if len(req.VolumeId) == 0 {
return nil, status.Error(codes.InvalidArgument, "Volume ID cannot be empty")
}
if req.CapacityRange == nil {
return nil, status.Error(codes.InvalidArgument, "Request capacity cannot be empty")
}
if hookVal, hookMsg := s.execHook("ControllerExpandVolumeStart"); hookVal != codes.OK {
return nil, status.Errorf(hookVal, hookMsg)
}
s.volsRWL.Lock()
defer s.volsRWL.Unlock()
i, v := s.findVolNoLock("id", req.VolumeId)
if i < 0 {
return nil, status.Error(codes.NotFound, req.VolumeId)
}
if s.config.DisableOnlineExpansion && MockVolumes[v.GetVolumeId()].ISControllerPublished {
return nil, status.Error(codes.FailedPrecondition, "volume is published and online volume expansion is not supported")
}
requestBytes := req.CapacityRange.RequiredBytes
if v.CapacityBytes > requestBytes {
return nil, status.Error(codes.InvalidArgument, "cannot change volume capacity to a smaller size")
}
resp := &csi.ControllerExpandVolumeResponse{
CapacityBytes: requestBytes,
NodeExpansionRequired: s.config.NodeExpansionRequired,
}
// Check to see if the volume already satisfied request size.
if v.CapacityBytes == requestBytes {
log.WithField("volumeID", v.VolumeId).Infof("Volume capacity is already %d, no need to expand", requestBytes)
return resp, nil
}
// Update volume's capacity to the requested size.
v.CapacityBytes = requestBytes
s.vols[i] = v
if hookVal, hookMsg := s.execHook("ControllerExpandVolumeEnd"); hookVal != codes.OK {
return nil, status.Errorf(hookVal, hookMsg)
}
return resp, nil
}
func getSnapshotById(s *service, req *csi.ListSnapshotsRequest) (*csi.ListSnapshotsResponse, error) {
if len(req.GetSnapshotId()) != 0 {
i, snapshot := s.snapshots.FindSnapshot("id", req.GetSnapshotId())
if i < 0 {
return &csi.ListSnapshotsResponse{}, nil
}
if len(req.GetSourceVolumeId()) != 0 {
if snapshot.SnapshotCSI.GetSourceVolumeId() != req.GetSourceVolumeId() {
return &csi.ListSnapshotsResponse{}, nil
}
}
return &csi.ListSnapshotsResponse{
Entries: []*csi.ListSnapshotsResponse_Entry{
{
Snapshot: &snapshot.SnapshotCSI,
},
},
}, nil
}
return nil, nil
}
func getSnapshotByVolumeId(s *service, req *csi.ListSnapshotsRequest) (*csi.ListSnapshotsResponse, error) {
if len(req.GetSourceVolumeId()) != 0 {
i, snapshot := s.snapshots.FindSnapshot("sourceVolumeId", req.SourceVolumeId)
if i < 0 {
return &csi.ListSnapshotsResponse{}, nil
}
return &csi.ListSnapshotsResponse{
Entries: []*csi.ListSnapshotsResponse_Entry{
{
Snapshot: &snapshot.SnapshotCSI,
},
},
}, nil
}
return nil, nil
}
func getAllSnapshots(s *service, req *csi.ListSnapshotsRequest) (*csi.ListSnapshotsResponse, error) {
// Copy the mock snapshots into a new slice in order to avoid
// locking the service's snapshot slice for the duration of the
// ListSnapshots RPC.
readyToUse := true
snapshots := s.snapshots.List(readyToUse)
var (
ulenSnapshots = int32(len(snapshots))
maxEntries = req.MaxEntries
startingToken int32
)
if v := req.StartingToken; v != "" {
i, err := strconv.ParseUint(v, 10, 32)
if err != nil {
return nil, status.Errorf(
codes.Aborted,
"startingToken=%d !< int32=%d",
startingToken, math.MaxUint32)
}
startingToken = int32(i)
}
if startingToken > ulenSnapshots {
return nil, status.Errorf(
codes.Aborted,
"startingToken=%d > len(snapshots)=%d",
startingToken, ulenSnapshots)
}
// Discern the number of remaining entries.
rem := ulenSnapshots - startingToken
// If maxEntries is 0 or greater than the number of remaining entries then
// set maxEntries to the number of remaining entries.
if maxEntries == 0 || maxEntries > rem {
maxEntries = rem
}
var (
i int
j = startingToken
entries = make(
[]*csi.ListSnapshotsResponse_Entry,
maxEntries)
)
for i = 0; i < len(entries); i++ {
entries[i] = &csi.ListSnapshotsResponse_Entry{
Snapshot: &snapshots[j],
}
j++
}
var nextToken string
if n := startingToken + int32(i); n < ulenSnapshots {
nextToken = fmt.Sprintf("%d", n)
}
return &csi.ListSnapshotsResponse{
Entries: entries,
NextToken: nextToken,
}, nil
}

View File

@ -0,0 +1,24 @@
package service
// Predefinded constants for the JavaScript hooks, they must correspond to the
// error codes used by gRPC, see:
// https://github.com/grpc/grpc-go/blob/master/codes/codes.go
const (
grpcJSCodes string = `OK = 0;
CANCELED = 1;
UNKNOWN = 2;
INVALIDARGUMENT = 3;
DEADLINEEXCEEDED = 4;
NOTFOUND = 5;
ALREADYEXISTS = 6;
PERMISSIONDENIED = 7;
RESOURCEEXHAUSTED = 8;
FAILEDPRECONDITION = 9;
ABORTED = 10;
OUTOFRANGE = 11;
UNIMPLEMENTED = 12;
INTERNAL = 13;
UNAVAILABLE = 14;
DATALOSS = 15;
UNAUTHENTICATED = 16`
)

View File

@ -0,0 +1,74 @@
package service
import (
"golang.org/x/net/context"
"github.com/container-storage-interface/spec/lib/go/csi"
"github.com/golang/protobuf/ptypes/wrappers"
)
func (s *service) GetPluginInfo(
ctx context.Context,
req *csi.GetPluginInfoRequest) (
*csi.GetPluginInfoResponse, error) {
return &csi.GetPluginInfoResponse{
Name: s.config.DriverName,
VendorVersion: VendorVersion,
Manifest: Manifest,
}, nil
}
func (s *service) Probe(
ctx context.Context,
req *csi.ProbeRequest) (
*csi.ProbeResponse, error) {
return &csi.ProbeResponse{
Ready: &wrappers.BoolValue{Value: true},
}, nil
}
func (s *service) GetPluginCapabilities(
ctx context.Context,
req *csi.GetPluginCapabilitiesRequest) (
*csi.GetPluginCapabilitiesResponse, error) {
volExpType := csi.PluginCapability_VolumeExpansion_ONLINE
if s.config.DisableOnlineExpansion {
volExpType = csi.PluginCapability_VolumeExpansion_OFFLINE
}
capabilities := []*csi.PluginCapability{
{
Type: &csi.PluginCapability_Service_{
Service: &csi.PluginCapability_Service{
Type: csi.PluginCapability_Service_CONTROLLER_SERVICE,
},
},
},
{
Type: &csi.PluginCapability_VolumeExpansion_{
VolumeExpansion: &csi.PluginCapability_VolumeExpansion{
Type: volExpType,
},
},
},
}
if s.config.EnableTopology {
capabilities = append(capabilities,
&csi.PluginCapability{
Type: &csi.PluginCapability_Service_{
Service: &csi.PluginCapability_Service{
Type: csi.PluginCapability_Service_VOLUME_ACCESSIBILITY_CONSTRAINTS,
},
},
})
}
return &csi.GetPluginCapabilitiesResponse{
Capabilities: capabilities,
}, nil
}

View File

@ -0,0 +1,460 @@
package service
import (
"fmt"
"os"
"path"
"strconv"
"google.golang.org/grpc/codes"
"google.golang.org/grpc/status"
"golang.org/x/net/context"
"github.com/container-storage-interface/spec/lib/go/csi"
)
func (s *service) NodeStageVolume(
ctx context.Context,
req *csi.NodeStageVolumeRequest) (
*csi.NodeStageVolumeResponse, error) {
if hookVal, hookMsg := s.execHook("NodeStageVolumeStart"); hookVal != codes.OK {
return nil, status.Errorf(hookVal, hookMsg)
}
device, ok := req.PublishContext["device"]
if !ok {
if s.config.DisableAttach {
device = "mock device"
} else {
return nil, status.Error(
codes.InvalidArgument,
"stage volume info 'device' key required")
}
}
if len(req.GetVolumeId()) == 0 {
return nil, status.Error(codes.InvalidArgument, "Volume ID cannot be empty")
}
if len(req.GetStagingTargetPath()) == 0 {
return nil, status.Error(codes.InvalidArgument, "Staging Target Path cannot be empty")
}
if req.GetVolumeCapability() == nil {
return nil, status.Error(codes.InvalidArgument, "Volume Capability cannot be empty")
}
exists, err := checkTargetExists(req.StagingTargetPath)
if err != nil {
return nil, status.Error(codes.Internal, err.Error())
}
if !exists {
status.Errorf(codes.Internal, "staging target path %s does not exist", req.StagingTargetPath)
}
s.volsRWL.Lock()
defer s.volsRWL.Unlock()
i, v := s.findVolNoLock("id", req.VolumeId)
if i < 0 {
return nil, status.Error(codes.NotFound, req.VolumeId)
}
// nodeStgPathKey is the key in the volume's attributes that is set to a
// mock stage path if the volume has been published by the node
nodeStgPathKey := path.Join(s.nodeID, req.StagingTargetPath)
// Check to see if the volume has already been staged.
if v.VolumeContext[nodeStgPathKey] != "" {
// TODO: Check for the capabilities to be equal. Return "ALREADY_EXISTS"
// if the capabilities don't match.
return &csi.NodeStageVolumeResponse{}, nil
}
// Stage the volume.
v.VolumeContext[nodeStgPathKey] = device
s.vols[i] = v
if hookVal, hookMsg := s.execHook("NodeStageVolumeEnd"); hookVal != codes.OK {
return nil, status.Errorf(hookVal, hookMsg)
}
return &csi.NodeStageVolumeResponse{}, nil
}
func (s *service) NodeUnstageVolume(
ctx context.Context,
req *csi.NodeUnstageVolumeRequest) (
*csi.NodeUnstageVolumeResponse, error) {
if len(req.GetVolumeId()) == 0 {
return nil, status.Error(codes.InvalidArgument, "Volume ID cannot be empty")
}
if len(req.GetStagingTargetPath()) == 0 {
return nil, status.Error(codes.InvalidArgument, "Staging Target Path cannot be empty")
}
if hookVal, hookMsg := s.execHook("NodeUnstageVolumeStart"); hookVal != codes.OK {
return nil, status.Errorf(hookVal, hookMsg)
}
s.volsRWL.Lock()
defer s.volsRWL.Unlock()
i, v := s.findVolNoLock("id", req.VolumeId)
if i < 0 {
return nil, status.Error(codes.NotFound, req.VolumeId)
}
// nodeStgPathKey is the key in the volume's attributes that is set to a
// mock stage path if the volume has been published by the node
nodeStgPathKey := path.Join(s.nodeID, req.StagingTargetPath)
// Check to see if the volume has already been unstaged.
if v.VolumeContext[nodeStgPathKey] == "" {
return &csi.NodeUnstageVolumeResponse{}, nil
}
// Unpublish the volume.
delete(v.VolumeContext, nodeStgPathKey)
s.vols[i] = v
if hookVal, hookMsg := s.execHook("NodeUnstageVolumeEnd"); hookVal != codes.OK {
return nil, status.Errorf(hookVal, hookMsg)
}
return &csi.NodeUnstageVolumeResponse{}, nil
}
func (s *service) NodePublishVolume(
ctx context.Context,
req *csi.NodePublishVolumeRequest) (
*csi.NodePublishVolumeResponse, error) {
if hookVal, hookMsg := s.execHook("NodePublishVolumeStart"); hookVal != codes.OK {
return nil, status.Errorf(hookVal, hookMsg)
}
ephemeralVolume := req.GetVolumeContext()["csi.storage.k8s.io/ephemeral"] == "true"
device, ok := req.PublishContext["device"]
if !ok {
if ephemeralVolume || s.config.DisableAttach {
device = "mock device"
} else {
return nil, status.Error(
codes.InvalidArgument,
"stage volume info 'device' key required")
}
}
if len(req.GetVolumeId()) == 0 {
return nil, status.Error(codes.InvalidArgument, "Volume ID cannot be empty")
}
if len(req.GetTargetPath()) == 0 {
return nil, status.Error(codes.InvalidArgument, "Target Path cannot be empty")
}
if req.GetVolumeCapability() == nil {
return nil, status.Error(codes.InvalidArgument, "Volume Capability cannot be empty")
}
// May happen with old (or, at this time, even the current) Kubernetes
// although it shouldn't (https://github.com/kubernetes/kubernetes/issues/75535).
exists, err := checkTargetExists(req.TargetPath)
if err != nil {
return nil, status.Error(codes.Internal, err.Error())
}
if !s.config.PermissiveTargetPath && exists {
status.Errorf(codes.Internal, "target path %s does exist", req.TargetPath)
}
s.volsRWL.Lock()
defer s.volsRWL.Unlock()
i, v := s.findVolNoLock("id", req.VolumeId)
if i < 0 && !ephemeralVolume {
return nil, status.Error(codes.NotFound, req.VolumeId)
}
if i >= 0 && ephemeralVolume {
return nil, status.Error(codes.AlreadyExists, req.VolumeId)
}
// nodeMntPathKey is the key in the volume's attributes that is set to a
// mock mount path if the volume has been published by the node
nodeMntPathKey := path.Join(s.nodeID, req.TargetPath)
// Check to see if the volume has already been published.
if v.VolumeContext[nodeMntPathKey] != "" {
// Requests marked Readonly fail due to volumes published by
// the Mock driver supporting only RW mode.
if req.Readonly {
return nil, status.Error(codes.AlreadyExists, req.VolumeId)
}
return &csi.NodePublishVolumeResponse{}, nil
}
// Publish the volume.
if ephemeralVolume {
MockVolumes[req.VolumeId] = Volume{
ISEphemeral: true,
}
} else {
if req.GetTargetPath() != "" {
exists, err := checkTargetExists(req.GetTargetPath())
if err != nil {
return nil, status.Error(codes.Internal, err.Error())
}
if !exists {
// If target path does not exist we need to create the directory where volume will be staged
if err = os.Mkdir(req.TargetPath, os.FileMode(0755)); err != nil {
msg := fmt.Sprintf("NodePublishVolume: could not create target dir %q: %v", req.TargetPath, err)
return nil, status.Error(codes.Internal, msg)
}
}
v.VolumeContext[nodeMntPathKey] = req.GetTargetPath()
} else {
v.VolumeContext[nodeMntPathKey] = device
}
s.vols[i] = v
}
if hookVal, hookMsg := s.execHook("NodePublishVolumeEnd"); hookVal != codes.OK {
return nil, status.Errorf(hookVal, hookMsg)
}
return &csi.NodePublishVolumeResponse{}, nil
}
func (s *service) NodeUnpublishVolume(
ctx context.Context,
req *csi.NodeUnpublishVolumeRequest) (
*csi.NodeUnpublishVolumeResponse, error) {
if len(req.GetVolumeId()) == 0 {
return nil, status.Error(codes.InvalidArgument, "Volume ID cannot be empty")
}
if len(req.GetTargetPath()) == 0 {
return nil, status.Error(codes.InvalidArgument, "Target Path cannot be empty")
}
if hookVal, hookMsg := s.execHook("NodeUnpublishVolumeStart"); hookVal != codes.OK {
return nil, status.Errorf(hookVal, hookMsg)
}
s.volsRWL.Lock()
defer s.volsRWL.Unlock()
ephemeralVolume := MockVolumes[req.VolumeId].ISEphemeral
i, v := s.findVolNoLock("id", req.VolumeId)
if i < 0 && !ephemeralVolume {
return nil, status.Error(codes.NotFound, req.VolumeId)
}
if ephemeralVolume {
delete(MockVolumes, req.VolumeId)
} else {
// nodeMntPathKey is the key in the volume's attributes that is set to a
// mock mount path if the volume has been published by the node
nodeMntPathKey := path.Join(s.nodeID, req.TargetPath)
// Check to see if the volume has already been unpublished.
if v.VolumeContext[nodeMntPathKey] == "" {
return &csi.NodeUnpublishVolumeResponse{}, nil
}
// Delete any created paths
err := os.RemoveAll(v.VolumeContext[nodeMntPathKey])
if err != nil {
return nil, status.Errorf(codes.Internal, "Unable to delete previously created target directory")
}
// Unpublish the volume.
delete(v.VolumeContext, nodeMntPathKey)
s.vols[i] = v
}
if hookVal, hookMsg := s.execHook("NodeUnpublishVolumeEnd"); hookVal != codes.OK {
return nil, status.Errorf(hookVal, hookMsg)
}
return &csi.NodeUnpublishVolumeResponse{}, nil
}
func (s *service) NodeExpandVolume(ctx context.Context, req *csi.NodeExpandVolumeRequest) (*csi.NodeExpandVolumeResponse, error) {
if len(req.GetVolumeId()) == 0 {
return nil, status.Error(codes.InvalidArgument, "Volume ID cannot be empty")
}
if len(req.GetVolumePath()) == 0 {
return nil, status.Error(codes.InvalidArgument, "Volume Path cannot be empty")
}
if hookVal, hookMsg := s.execHook("NodeExpandVolumeStart"); hookVal != codes.OK {
return nil, status.Errorf(hookVal, hookMsg)
}
s.volsRWL.Lock()
defer s.volsRWL.Unlock()
i, v := s.findVolNoLock("id", req.VolumeId)
if i < 0 {
return nil, status.Error(codes.NotFound, req.VolumeId)
}
// TODO: NodeExpandVolume MUST be called after successful NodeStageVolume as we has STAGE_UNSTAGE_VOLUME node capacity.
resp := &csi.NodeExpandVolumeResponse{}
var requestCapacity int64 = 0
if req.GetCapacityRange() != nil {
requestCapacity = req.CapacityRange.GetRequiredBytes()
resp.CapacityBytes = requestCapacity
}
// fsCapacityKey is the key in the volume's attributes that is set to the file system's size.
fsCapacityKey := path.Join(s.nodeID, req.GetVolumePath(), "size")
// Update volume's fs capacity to requested size.
if requestCapacity > 0 {
v.VolumeContext[fsCapacityKey] = strconv.FormatInt(requestCapacity, 10)
s.vols[i] = v
}
if hookVal, hookMsg := s.execHook("NodeExpandVolumeEnd"); hookVal != codes.OK {
return nil, status.Errorf(hookVal, hookMsg)
}
return resp, nil
}
func (s *service) NodeGetCapabilities(
ctx context.Context,
req *csi.NodeGetCapabilitiesRequest) (
*csi.NodeGetCapabilitiesResponse, error) {
if hookVal, hookMsg := s.execHook("NodeGetCapabilities"); hookVal != codes.OK {
return nil, status.Errorf(hookVal, hookMsg)
}
capabilities := []*csi.NodeServiceCapability{
{
Type: &csi.NodeServiceCapability_Rpc{
Rpc: &csi.NodeServiceCapability_RPC{
Type: csi.NodeServiceCapability_RPC_UNKNOWN,
},
},
},
{
Type: &csi.NodeServiceCapability_Rpc{
Rpc: &csi.NodeServiceCapability_RPC{
Type: csi.NodeServiceCapability_RPC_STAGE_UNSTAGE_VOLUME,
},
},
},
{
Type: &csi.NodeServiceCapability_Rpc{
Rpc: &csi.NodeServiceCapability_RPC{
Type: csi.NodeServiceCapability_RPC_GET_VOLUME_STATS,
},
},
},
{
Type: &csi.NodeServiceCapability_Rpc{
Rpc: &csi.NodeServiceCapability_RPC{
Type: csi.NodeServiceCapability_RPC_VOLUME_CONDITION,
},
},
},
}
if s.config.NodeExpansionRequired {
capabilities = append(capabilities, &csi.NodeServiceCapability{
Type: &csi.NodeServiceCapability_Rpc{
Rpc: &csi.NodeServiceCapability_RPC{
Type: csi.NodeServiceCapability_RPC_EXPAND_VOLUME,
},
},
})
}
return &csi.NodeGetCapabilitiesResponse{
Capabilities: capabilities,
}, nil
}
func (s *service) NodeGetInfo(ctx context.Context,
req *csi.NodeGetInfoRequest) (*csi.NodeGetInfoResponse, error) {
if hookVal, hookMsg := s.execHook("NodeGetInfo"); hookVal != codes.OK {
return nil, status.Errorf(hookVal, hookMsg)
}
csiNodeResponse := &csi.NodeGetInfoResponse{
NodeId: s.nodeID,
}
if s.config.AttachLimit > 0 {
csiNodeResponse.MaxVolumesPerNode = s.config.AttachLimit
}
if s.config.EnableTopology {
csiNodeResponse.AccessibleTopology = &csi.Topology{
Segments: map[string]string{
TopologyKey: TopologyValue,
},
}
}
return csiNodeResponse, nil
}
func (s *service) NodeGetVolumeStats(ctx context.Context,
req *csi.NodeGetVolumeStatsRequest) (*csi.NodeGetVolumeStatsResponse, error) {
if hookVal, hookMsg := s.execHook("NodeGetVolumeStatsStart"); hookVal != codes.OK {
return nil, status.Errorf(hookVal, hookMsg)
}
resp := &csi.NodeGetVolumeStatsResponse{
VolumeCondition: &csi.VolumeCondition{},
}
if len(req.GetVolumeId()) == 0 {
return nil, status.Error(codes.InvalidArgument, "Volume ID cannot be empty")
}
if len(req.GetVolumePath()) == 0 {
return nil, status.Error(codes.InvalidArgument, "Volume Path cannot be empty")
}
i, v := s.findVolNoLock("id", req.VolumeId)
if i < 0 {
resp.VolumeCondition.Abnormal = true
resp.VolumeCondition.Message = "Volume not found"
return resp, status.Error(codes.NotFound, req.VolumeId)
}
nodeMntPathKey := path.Join(s.nodeID, req.VolumePath)
_, exists := v.VolumeContext[nodeMntPathKey]
if !exists {
msg := fmt.Sprintf("volume %q doest not exist on the specified path %q", req.VolumeId, req.VolumePath)
resp.VolumeCondition.Abnormal = true
resp.VolumeCondition.Message = msg
return resp, status.Errorf(codes.NotFound, msg)
}
if hookVal, hookMsg := s.execHook("NodeGetVolumeStatsEnd"); hookVal != codes.OK {
return nil, status.Errorf(hookVal, hookMsg)
}
resp.Usage = []*csi.VolumeUsage{
{
Total: v.GetCapacityBytes(),
Unit: csi.VolumeUsage_BYTES,
},
}
return resp, nil
}
// checkTargetExists checks if a given path exists.
func checkTargetExists(targetPath string) (bool, error) {
_, err := os.Stat(targetPath)
switch {
case err == nil:
return true, nil
case os.IsNotExist(err):
return false, nil
default:
return false, err
}
}

View File

@ -0,0 +1,293 @@
package service
import (
"fmt"
"reflect"
"strings"
"sync"
"sync/atomic"
"k8s.io/klog"
"github.com/container-storage-interface/spec/lib/go/csi"
"github.com/kubernetes-csi/csi-test/v4/mock/cache"
"golang.org/x/net/context"
"google.golang.org/grpc/codes"
"github.com/golang/protobuf/ptypes"
"github.com/robertkrimen/otto"
)
const (
// Name is the name of the CSI plug-in.
Name = "io.kubernetes.storage.mock"
// VendorVersion is the version returned by GetPluginInfo.
VendorVersion = "0.3.0"
// TopologyKey simulates a per-node topology.
TopologyKey = Name + "/node"
// TopologyValue is the one, fixed node on which the driver runs.
TopologyValue = "some-mock-node"
)
// Manifest is the SP's manifest.
var Manifest = map[string]string{
"url": "https://github.com/kubernetes-csi/csi-test/mock",
}
// JavaScript hooks to be run to perform various tests
type Hooks struct {
Globals string `yaml:"globals"` // will be executed once before all other scripts
CreateVolumeStart string `yaml:"createVolumeStart"`
CreateVolumeEnd string `yaml:"createVolumeEnd"`
DeleteVolumeStart string `yaml:"deleteVolumeStart"`
DeleteVolumeEnd string `yaml:"deleteVolumeEnd"`
ControllerPublishVolumeStart string `yaml:"controllerPublishVolumeStart"`
ControllerPublishVolumeEnd string `yaml:"controllerPublishVolumeEnd"`
ControllerUnpublishVolumeStart string `yaml:"controllerUnpublishVolumeStart"`
ControllerUnpublishVolumeEnd string `yaml:"controllerUnpublishVolumeEnd"`
ValidateVolumeCapabilities string `yaml:"validateVolumeCapabilities"`
ListVolumesStart string `yaml:"listVolumesStart"`
ListVolumesEnd string `yaml:"listVolumesEnd"`
GetCapacity string `yaml:"getCapacity"`
ControllerGetCapabilitiesStart string `yaml:"controllerGetCapabilitiesStart"`
ControllerGetCapabilitiesEnd string `yaml:"controllerGetCapabilitiesEnd"`
CreateSnapshotStart string `yaml:"createSnapshotStart"`
CreateSnapshotEnd string `yaml:"createSnapshotEnd"`
DeleteSnapshotStart string `yaml:"deleteSnapshotStart"`
DeleteSnapshotEnd string `yaml:"deleteSnapshotEnd"`
ListSnapshots string `yaml:"listSnapshots"`
ControllerExpandVolumeStart string `yaml:"controllerExpandVolumeStart"`
ControllerExpandVolumeEnd string `yaml:"controllerExpandVolumeEnd"`
NodeStageVolumeStart string `yaml:"nodeStageVolumeStart"`
NodeStageVolumeEnd string `yaml:"nodeStageVolumeEnd"`
NodeUnstageVolumeStart string `yaml:"nodeUnstageVolumeStart"`
NodeUnstageVolumeEnd string `yaml:"nodeUnstageVolumeEnd"`
NodePublishVolumeStart string `yaml:"nodePublishVolumeStart"`
NodePublishVolumeEnd string `yaml:"nodePublishVolumeEnd"`
NodeUnpublishVolumeStart string `yaml:"nodeUnpublishVolumeStart"`
NodeUnpublishVolumeEnd string `yaml:"nodeUnpublishVolumeEnd"`
NodeExpandVolumeStart string `yaml:"nodeExpandVolumeStart"`
NodeExpandVolumeEnd string `yaml:"nodeExpandVolumeEnd"`
NodeGetCapabilities string `yaml:"nodeGetCapabilities"`
NodeGetInfo string `yaml:"nodeGetInfo"`
NodeGetVolumeStatsStart string `yaml:"nodeGetVolumeStatsStart"`
NodeGetVolumeStatsEnd string `yaml:"nodeGetVolumeStatsEnd"`
}
type Config struct {
DisableAttach bool
DriverName string
AttachLimit int64
NodeExpansionRequired bool
DisableControllerExpansion bool
DisableOnlineExpansion bool
PermissiveTargetPath bool
EnableTopology bool
ExecHooks *Hooks
}
// Service is the CSI Mock service provider.
type Service interface {
csi.ControllerServer
csi.IdentityServer
csi.NodeServer
}
type service struct {
sync.Mutex
nodeID string
vols []csi.Volume
volsRWL sync.RWMutex
volsNID uint64
snapshots cache.SnapshotCache
snapshotsNID uint64
config Config
hooksVm *otto.Otto
}
type Volume struct {
VolumeCSI csi.Volume
NodeID string
ISStaged bool
ISPublished bool
ISEphemeral bool
ISControllerPublished bool
StageTargetPath string
TargetPath string
}
var MockVolumes map[string]Volume
// New returns a new Service.
func New(config Config) Service {
s := &service{
nodeID: config.DriverName,
config: config,
}
if config.ExecHooks != nil {
s.hooksVm = otto.New()
s.hooksVm.Run(grpcJSCodes) // set global variables with gRPC error codes
_, err := s.hooksVm.Run(s.config.ExecHooks.Globals)
if err != nil {
klog.Exitf("Error encountered in the global exec hook: %v. Exiting\n", err)
}
}
s.snapshots = cache.NewSnapshotCache()
s.vols = []csi.Volume{
s.newVolume("Mock Volume 1", gib100),
s.newVolume("Mock Volume 2", gib100),
s.newVolume("Mock Volume 3", gib100),
}
MockVolumes = map[string]Volume{}
s.snapshots.Add(s.newSnapshot("Mock Snapshot 1", "1", map[string]string{"Description": "snapshot 1"}))
s.snapshots.Add(s.newSnapshot("Mock Snapshot 2", "2", map[string]string{"Description": "snapshot 2"}))
s.snapshots.Add(s.newSnapshot("Mock Snapshot 3", "3", map[string]string{"Description": "snapshot 3"}))
return s
}
const (
kib int64 = 1024
mib int64 = kib * 1024
gib int64 = mib * 1024
gib100 int64 = gib * 100
tib int64 = gib * 1024
tib100 int64 = tib * 100
)
func (s *service) newVolume(name string, capcity int64) csi.Volume {
vol := csi.Volume{
VolumeId: fmt.Sprintf("%d", atomic.AddUint64(&s.volsNID, 1)),
VolumeContext: map[string]string{"name": name},
CapacityBytes: capcity,
}
s.setTopology(&vol)
return vol
}
func (s *service) newVolumeFromSnapshot(name string, capacity int64, snapshotID int) csi.Volume {
vol := s.newVolume(name, capacity)
vol.ContentSource = &csi.VolumeContentSource{
Type: &csi.VolumeContentSource_Snapshot{
Snapshot: &csi.VolumeContentSource_SnapshotSource{
SnapshotId: fmt.Sprintf("%d", snapshotID),
},
},
}
s.setTopology(&vol)
return vol
}
func (s *service) newVolumeFromVolume(name string, capacity int64, volumeID int) csi.Volume {
vol := s.newVolume(name, capacity)
vol.ContentSource = &csi.VolumeContentSource{
Type: &csi.VolumeContentSource_Volume{
Volume: &csi.VolumeContentSource_VolumeSource{
VolumeId: fmt.Sprintf("%d", volumeID),
},
},
}
s.setTopology(&vol)
return vol
}
func (s *service) setTopology(vol *csi.Volume) {
if s.config.EnableTopology {
vol.AccessibleTopology = []*csi.Topology{
&csi.Topology{
Segments: map[string]string{
TopologyKey: TopologyValue,
},
},
}
}
}
func (s *service) findVol(k, v string) (volIdx int, volInfo csi.Volume) {
s.volsRWL.RLock()
defer s.volsRWL.RUnlock()
return s.findVolNoLock(k, v)
}
func (s *service) findVolNoLock(k, v string) (volIdx int, volInfo csi.Volume) {
volIdx = -1
for i, vi := range s.vols {
switch k {
case "id":
if strings.EqualFold(v, vi.GetVolumeId()) {
return i, vi
}
case "name":
if n, ok := vi.VolumeContext["name"]; ok && strings.EqualFold(v, n) {
return i, vi
}
}
}
return
}
func (s *service) findVolByName(
ctx context.Context, name string) (int, csi.Volume) {
return s.findVol("name", name)
}
func (s *service) findVolByID(
ctx context.Context, id string) (int, csi.Volume) {
return s.findVol("id", id)
}
func (s *service) newSnapshot(name, sourceVolumeId string, parameters map[string]string) cache.Snapshot {
ptime := ptypes.TimestampNow()
return cache.Snapshot{
Name: name,
Parameters: parameters,
SnapshotCSI: csi.Snapshot{
SnapshotId: fmt.Sprintf("%d", atomic.AddUint64(&s.snapshotsNID, 1)),
CreationTime: ptime,
SourceVolumeId: sourceVolumeId,
ReadyToUse: true,
},
}
}
// getAttachCount returns the number of attached volumes on the node.
func (s *service) getAttachCount(devPathKey string) int64 {
var count int64
for _, v := range s.vols {
if device := v.VolumeContext[devPathKey]; device != "" {
count++
}
}
return count
}
func (s *service) execHook(hookName string) (codes.Code, string) {
if s.hooksVm != nil {
script := reflect.ValueOf(*s.config.ExecHooks).FieldByName(hookName).String()
if len(script) > 0 {
result, err := s.hooksVm.Run(script)
if err != nil {
klog.Exitf("Exec hook %s error: %v; exiting\n", hookName, err)
}
rv, err := result.ToInteger()
if err == nil {
// Function returned an integer, use it
return codes.Code(rv), fmt.Sprintf("Exec hook %s returned non-OK code", hookName)
} else {
// Function returned non-integer data type, discard it
return codes.OK, ""
}
}
}
return codes.OK, ""
}