diff --git a/test/e2e/storage/drivers/csi-test/driver/driver-controller.go b/test/e2e/storage/drivers/csi-test/driver/driver-controller.go new file mode 100644 index 00000000000..1d8d2bd771e --- /dev/null +++ b/test/e2e/storage/drivers/csi-test/driver/driver-controller.go @@ -0,0 +1,110 @@ +/* +Copyright 2019 Kubernetes Authors + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +*/ + +package driver + +import ( + "context" + "net" + "sync" + + "google.golang.org/grpc/reflection" + + csi "github.com/container-storage-interface/spec/lib/go/csi" + "google.golang.org/grpc" +) + +// CSIDriverControllerServer is the Controller service component of the driver. +type CSIDriverControllerServer struct { + Controller csi.ControllerServer + Identity csi.IdentityServer +} + +// CSIDriverController is the CSI Driver Controller backend. +type CSIDriverController struct { + listener net.Listener + server *grpc.Server + controllerServer *CSIDriverControllerServer + wg sync.WaitGroup + running bool + lock sync.Mutex + creds *CSICreds +} + +func NewCSIDriverController(controllerServer *CSIDriverControllerServer) *CSIDriverController { + return &CSIDriverController{ + controllerServer: controllerServer, + } +} + +func (c *CSIDriverController) goServe(started chan<- bool) { + goServe(c.server, &c.wg, c.listener, started) +} + +func (c *CSIDriverController) Address() string { + return c.listener.Addr().String() +} + +func (c *CSIDriverController) Start(l net.Listener) error { + c.lock.Lock() + defer c.lock.Unlock() + + // Set listener. + c.listener = l + + // Create a new grpc server. + c.server = grpc.NewServer( + grpc.UnaryInterceptor(c.callInterceptor), + ) + + if c.controllerServer.Controller != nil { + csi.RegisterControllerServer(c.server, c.controllerServer.Controller) + } + if c.controllerServer.Identity != nil { + csi.RegisterIdentityServer(c.server, c.controllerServer.Identity) + } + + reflection.Register(c.server) + + waitForServer := make(chan bool) + c.goServe(waitForServer) + <-waitForServer + c.running = true + return nil +} + +func (c *CSIDriverController) Stop() { + stop(&c.lock, &c.wg, c.server, c.running) +} + +func (c *CSIDriverController) Close() { + c.server.Stop() +} + +func (c *CSIDriverController) IsRunning() bool { + c.lock.Lock() + defer c.lock.Unlock() + + return c.running +} + +func (c *CSIDriverController) SetDefaultCreds() { + setDefaultCreds(c.creds) +} + +func (c *CSIDriverController) callInterceptor(ctx context.Context, req interface{}, info *grpc.UnaryServerInfo, handler grpc.UnaryHandler) (interface{}, error) { + return callInterceptor(ctx, c.creds, req, info, handler) +} diff --git a/test/e2e/storage/drivers/csi-test/driver/driver-node.go b/test/e2e/storage/drivers/csi-test/driver/driver-node.go new file mode 100644 index 00000000000..7720bfc493a --- /dev/null +++ b/test/e2e/storage/drivers/csi-test/driver/driver-node.go @@ -0,0 +1,109 @@ +/* +Copyright 2019 Kubernetes Authors + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +*/ + +package driver + +import ( + context "context" + "net" + "sync" + + csi "github.com/container-storage-interface/spec/lib/go/csi" + "google.golang.org/grpc" + "google.golang.org/grpc/reflection" +) + +// CSIDriverNodeServer is the Node service component of the driver. +type CSIDriverNodeServer struct { + Node csi.NodeServer + Identity csi.IdentityServer +} + +// CSIDriverNode is the CSI Driver Node backend. +type CSIDriverNode struct { + listener net.Listener + server *grpc.Server + nodeServer *CSIDriverNodeServer + wg sync.WaitGroup + running bool + lock sync.Mutex + creds *CSICreds +} + +func NewCSIDriverNode(nodeServer *CSIDriverNodeServer) *CSIDriverNode { + return &CSIDriverNode{ + nodeServer: nodeServer, + } +} + +func (c *CSIDriverNode) goServe(started chan<- bool) { + goServe(c.server, &c.wg, c.listener, started) +} + +func (c *CSIDriverNode) Address() string { + return c.listener.Addr().String() +} + +func (c *CSIDriverNode) Start(l net.Listener) error { + c.lock.Lock() + defer c.lock.Unlock() + + // Set listener. + c.listener = l + + // Create a new grpc server. + c.server = grpc.NewServer( + grpc.UnaryInterceptor(c.callInterceptor), + ) + + if c.nodeServer.Node != nil { + csi.RegisterNodeServer(c.server, c.nodeServer.Node) + } + if c.nodeServer.Identity != nil { + csi.RegisterIdentityServer(c.server, c.nodeServer.Identity) + } + + reflection.Register(c.server) + + waitForServer := make(chan bool) + c.goServe(waitForServer) + <-waitForServer + c.running = true + return nil +} + +func (c *CSIDriverNode) Stop() { + stop(&c.lock, &c.wg, c.server, c.running) +} + +func (c *CSIDriverNode) Close() { + c.server.Stop() +} + +func (c *CSIDriverNode) IsRunning() bool { + c.lock.Lock() + defer c.lock.Unlock() + + return c.running +} + +func (c *CSIDriverNode) SetDefaultCreds() { + setDefaultCreds(c.creds) +} + +func (c *CSIDriverNode) callInterceptor(ctx context.Context, req interface{}, info *grpc.UnaryServerInfo, handler grpc.UnaryHandler) (interface{}, error) { + return callInterceptor(ctx, c.creds, req, info, handler) +} diff --git a/test/e2e/storage/drivers/csi-test/driver/driver.go b/test/e2e/storage/drivers/csi-test/driver/driver.go new file mode 100644 index 00000000000..33ffe99359d --- /dev/null +++ b/test/e2e/storage/drivers/csi-test/driver/driver.go @@ -0,0 +1,312 @@ +/* +Copyright 2017 Luis Pabón luis@portworx.com + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +*/ + +//go:generate mockgen -package=driver -destination=driver.mock.go github.com/container-storage-interface/spec/lib/go/csi IdentityServer,ControllerServer,NodeServer + +package driver + +import ( + "context" + "encoding/json" + "errors" + "net" + "sync" + + "google.golang.org/grpc/codes" + "google.golang.org/grpc/status" + "k8s.io/klog" + + "github.com/container-storage-interface/spec/lib/go/csi" + "google.golang.org/grpc" + "google.golang.org/grpc/reflection" +) + +var ( + // ErrNoCredentials is the error when a secret is enabled but not passed in the request. + ErrNoCredentials = errors.New("secret must be provided") + // ErrAuthFailed is the error when the secret is incorrect. + ErrAuthFailed = errors.New("authentication failed") +) + +// CSIDriverServers is a unified driver component with both Controller and Node +// services. +type CSIDriverServers struct { + Controller csi.ControllerServer + Identity csi.IdentityServer + Node csi.NodeServer +} + +// This is the key name in all the CSI secret objects. +const secretField = "secretKey" + +// CSICreds is a driver specific secret type. Drivers can have a key-val pair of +// secrets. This mock driver has a single string secret with secretField as the +// key. +type CSICreds struct { + CreateVolumeSecret string + DeleteVolumeSecret string + ControllerPublishVolumeSecret string + ControllerUnpublishVolumeSecret string + NodeStageVolumeSecret string + NodePublishVolumeSecret string + CreateSnapshotSecret string + DeleteSnapshotSecret string + ControllerValidateVolumeCapabilitiesSecret string +} + +type CSIDriver struct { + listener net.Listener + server *grpc.Server + servers *CSIDriverServers + wg sync.WaitGroup + running bool + lock sync.Mutex + creds *CSICreds +} + +func NewCSIDriver(servers *CSIDriverServers) *CSIDriver { + return &CSIDriver{ + servers: servers, + } +} + +func (c *CSIDriver) goServe(started chan<- bool) { + goServe(c.server, &c.wg, c.listener, started) +} + +func (c *CSIDriver) Address() string { + return c.listener.Addr().String() +} +func (c *CSIDriver) Start(l net.Listener) error { + c.lock.Lock() + defer c.lock.Unlock() + + // Set listener + c.listener = l + + // Create a new grpc server + c.server = grpc.NewServer( + grpc.UnaryInterceptor(c.callInterceptor), + ) + + // Register Mock servers + if c.servers.Controller != nil { + csi.RegisterControllerServer(c.server, c.servers.Controller) + } + if c.servers.Identity != nil { + csi.RegisterIdentityServer(c.server, c.servers.Identity) + } + if c.servers.Node != nil { + csi.RegisterNodeServer(c.server, c.servers.Node) + } + reflection.Register(c.server) + + // Start listening for requests + waitForServer := make(chan bool) + c.goServe(waitForServer) + <-waitForServer + c.running = true + return nil +} + +func (c *CSIDriver) Stop() { + stop(&c.lock, &c.wg, c.server, c.running) +} + +func (c *CSIDriver) Close() { + c.server.Stop() +} + +func (c *CSIDriver) IsRunning() bool { + c.lock.Lock() + defer c.lock.Unlock() + + return c.running +} + +// SetDefaultCreds sets the default secrets for CSI creds. +func (c *CSIDriver) SetDefaultCreds() { + setDefaultCreds(c.creds) +} + +func (c *CSIDriver) callInterceptor(ctx context.Context, req interface{}, info *grpc.UnaryServerInfo, handler grpc.UnaryHandler) (interface{}, error) { + return callInterceptor(ctx, c.creds, req, info, handler) +} + +// goServe starts a grpc server. +func goServe(server *grpc.Server, wg *sync.WaitGroup, listener net.Listener, started chan<- bool) { + wg.Add(1) + go func() { + defer wg.Done() + started <- true + err := server.Serve(listener) + if err != nil { + panic(err.Error()) + } + }() +} + +// stop stops a grpc server. +func stop(lock *sync.Mutex, wg *sync.WaitGroup, server *grpc.Server, running bool) { + lock.Lock() + defer lock.Unlock() + + if !running { + return + } + + server.Stop() + wg.Wait() +} + +// setDefaultCreds sets the default credentials, given a CSICreds instance. +func setDefaultCreds(creds *CSICreds) { + creds = &CSICreds{ + CreateVolumeSecret: "secretval1", + DeleteVolumeSecret: "secretval2", + ControllerPublishVolumeSecret: "secretval3", + ControllerUnpublishVolumeSecret: "secretval4", + NodeStageVolumeSecret: "secretval5", + NodePublishVolumeSecret: "secretval6", + CreateSnapshotSecret: "secretval7", + DeleteSnapshotSecret: "secretval8", + ControllerValidateVolumeCapabilitiesSecret: "secretval9", + } +} + +func callInterceptor(ctx context.Context, creds *CSICreds, req interface{}, info *grpc.UnaryServerInfo, handler grpc.UnaryHandler) (interface{}, error) { + err := authInterceptor(creds, req) + if err != nil { + logGRPC(info.FullMethod, req, nil, err) + return nil, err + } + rsp, err := handler(ctx, req) + logGRPC(info.FullMethod, req, rsp, err) + return rsp, err +} + +func authInterceptor(creds *CSICreds, req interface{}) error { + if creds != nil { + authenticated, authErr := isAuthenticated(req, creds) + if !authenticated { + if authErr == ErrNoCredentials { + return status.Error(codes.InvalidArgument, authErr.Error()) + } + if authErr == ErrAuthFailed { + return status.Error(codes.Unauthenticated, authErr.Error()) + } + } + } + return nil +} + +func logGRPC(method string, request, reply interface{}, err error) { + // Log JSON with the request and response for easier parsing + logMessage := struct { + Method string + Request interface{} + Response interface{} + // Error as string, for backward compatibility. + // "" on no error. + Error string + // Full error dump, to be able to parse out full gRPC error code and message separately in a test. + FullError error + }{ + Method: method, + Request: request, + Response: reply, + FullError: err, + } + + if err != nil { + logMessage.Error = err.Error() + } + + msg, _ := json.Marshal(logMessage) + klog.V(3).Infof("gRPCCall: %s\n", msg) +} + +func isAuthenticated(req interface{}, creds *CSICreds) (bool, error) { + switch r := req.(type) { + case *csi.CreateVolumeRequest: + return authenticateCreateVolume(r, creds) + case *csi.DeleteVolumeRequest: + return authenticateDeleteVolume(r, creds) + case *csi.ControllerPublishVolumeRequest: + return authenticateControllerPublishVolume(r, creds) + case *csi.ControllerUnpublishVolumeRequest: + return authenticateControllerUnpublishVolume(r, creds) + case *csi.NodeStageVolumeRequest: + return authenticateNodeStageVolume(r, creds) + case *csi.NodePublishVolumeRequest: + return authenticateNodePublishVolume(r, creds) + case *csi.CreateSnapshotRequest: + return authenticateCreateSnapshot(r, creds) + case *csi.DeleteSnapshotRequest: + return authenticateDeleteSnapshot(r, creds) + case *csi.ValidateVolumeCapabilitiesRequest: + return authenticateControllerValidateVolumeCapabilities(r, creds) + default: + return true, nil + } +} + +func authenticateCreateVolume(req *csi.CreateVolumeRequest, creds *CSICreds) (bool, error) { + return credsCheck(req.GetSecrets(), creds.CreateVolumeSecret) +} + +func authenticateDeleteVolume(req *csi.DeleteVolumeRequest, creds *CSICreds) (bool, error) { + return credsCheck(req.GetSecrets(), creds.DeleteVolumeSecret) +} + +func authenticateControllerPublishVolume(req *csi.ControllerPublishVolumeRequest, creds *CSICreds) (bool, error) { + return credsCheck(req.GetSecrets(), creds.ControllerPublishVolumeSecret) +} + +func authenticateControllerUnpublishVolume(req *csi.ControllerUnpublishVolumeRequest, creds *CSICreds) (bool, error) { + return credsCheck(req.GetSecrets(), creds.ControllerUnpublishVolumeSecret) +} + +func authenticateNodeStageVolume(req *csi.NodeStageVolumeRequest, creds *CSICreds) (bool, error) { + return credsCheck(req.GetSecrets(), creds.NodeStageVolumeSecret) +} + +func authenticateNodePublishVolume(req *csi.NodePublishVolumeRequest, creds *CSICreds) (bool, error) { + return credsCheck(req.GetSecrets(), creds.NodePublishVolumeSecret) +} + +func authenticateCreateSnapshot(req *csi.CreateSnapshotRequest, creds *CSICreds) (bool, error) { + return credsCheck(req.GetSecrets(), creds.CreateSnapshotSecret) +} + +func authenticateDeleteSnapshot(req *csi.DeleteSnapshotRequest, creds *CSICreds) (bool, error) { + return credsCheck(req.GetSecrets(), creds.DeleteSnapshotSecret) +} + +func authenticateControllerValidateVolumeCapabilities(req *csi.ValidateVolumeCapabilitiesRequest, creds *CSICreds) (bool, error) { + return credsCheck(req.GetSecrets(), creds.ControllerValidateVolumeCapabilitiesSecret) +} + +func credsCheck(secrets map[string]string, secretVal string) (bool, error) { + if len(secrets) == 0 { + return false, ErrNoCredentials + } + + if secrets[secretField] != secretVal { + return false, ErrAuthFailed + } + return true, nil +} diff --git a/test/e2e/storage/drivers/csi-test/driver/driver.mock.go b/test/e2e/storage/drivers/csi-test/driver/driver.mock.go new file mode 100644 index 00000000000..7eeaca0f022 --- /dev/null +++ b/test/e2e/storage/drivers/csi-test/driver/driver.mock.go @@ -0,0 +1,392 @@ +// Code generated by MockGen. DO NOT EDIT. +// Source: github.com/container-storage-interface/spec/lib/go/csi (interfaces: IdentityServer,ControllerServer,NodeServer) + +// Package driver is a generated GoMock package. +package driver + +import ( + context "context" + csi "github.com/container-storage-interface/spec/lib/go/csi" + gomock "github.com/golang/mock/gomock" + reflect "reflect" +) + +// MockIdentityServer is a mock of IdentityServer interface +type MockIdentityServer struct { + ctrl *gomock.Controller + recorder *MockIdentityServerMockRecorder +} + +// MockIdentityServerMockRecorder is the mock recorder for MockIdentityServer +type MockIdentityServerMockRecorder struct { + mock *MockIdentityServer +} + +// NewMockIdentityServer creates a new mock instance +func NewMockIdentityServer(ctrl *gomock.Controller) *MockIdentityServer { + mock := &MockIdentityServer{ctrl: ctrl} + mock.recorder = &MockIdentityServerMockRecorder{mock} + return mock +} + +// EXPECT returns an object that allows the caller to indicate expected use +func (m *MockIdentityServer) EXPECT() *MockIdentityServerMockRecorder { + return m.recorder +} + +// GetPluginCapabilities mocks base method +func (m *MockIdentityServer) GetPluginCapabilities(arg0 context.Context, arg1 *csi.GetPluginCapabilitiesRequest) (*csi.GetPluginCapabilitiesResponse, error) { + ret := m.ctrl.Call(m, "GetPluginCapabilities", arg0, arg1) + ret0, _ := ret[0].(*csi.GetPluginCapabilitiesResponse) + ret1, _ := ret[1].(error) + return ret0, ret1 +} + +// GetPluginCapabilities indicates an expected call of GetPluginCapabilities +func (mr *MockIdentityServerMockRecorder) GetPluginCapabilities(arg0, arg1 interface{}) *gomock.Call { + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetPluginCapabilities", reflect.TypeOf((*MockIdentityServer)(nil).GetPluginCapabilities), arg0, arg1) +} + +// GetPluginInfo mocks base method +func (m *MockIdentityServer) GetPluginInfo(arg0 context.Context, arg1 *csi.GetPluginInfoRequest) (*csi.GetPluginInfoResponse, error) { + ret := m.ctrl.Call(m, "GetPluginInfo", arg0, arg1) + ret0, _ := ret[0].(*csi.GetPluginInfoResponse) + ret1, _ := ret[1].(error) + return ret0, ret1 +} + +// GetPluginInfo indicates an expected call of GetPluginInfo +func (mr *MockIdentityServerMockRecorder) GetPluginInfo(arg0, arg1 interface{}) *gomock.Call { + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetPluginInfo", reflect.TypeOf((*MockIdentityServer)(nil).GetPluginInfo), arg0, arg1) +} + +// Probe mocks base method +func (m *MockIdentityServer) Probe(arg0 context.Context, arg1 *csi.ProbeRequest) (*csi.ProbeResponse, error) { + ret := m.ctrl.Call(m, "Probe", arg0, arg1) + ret0, _ := ret[0].(*csi.ProbeResponse) + ret1, _ := ret[1].(error) + return ret0, ret1 +} + +// Probe indicates an expected call of Probe +func (mr *MockIdentityServerMockRecorder) Probe(arg0, arg1 interface{}) *gomock.Call { + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Probe", reflect.TypeOf((*MockIdentityServer)(nil).Probe), arg0, arg1) +} + +// MockControllerServer is a mock of ControllerServer interface +type MockControllerServer struct { + ctrl *gomock.Controller + recorder *MockControllerServerMockRecorder +} + +// MockControllerServerMockRecorder is the mock recorder for MockControllerServer +type MockControllerServerMockRecorder struct { + mock *MockControllerServer +} + +// NewMockControllerServer creates a new mock instance +func NewMockControllerServer(ctrl *gomock.Controller) *MockControllerServer { + mock := &MockControllerServer{ctrl: ctrl} + mock.recorder = &MockControllerServerMockRecorder{mock} + return mock +} + +// EXPECT returns an object that allows the caller to indicate expected use +func (m *MockControllerServer) EXPECT() *MockControllerServerMockRecorder { + return m.recorder +} + +// ControllerExpandVolume mocks base method +func (m *MockControllerServer) ControllerExpandVolume(arg0 context.Context, arg1 *csi.ControllerExpandVolumeRequest) (*csi.ControllerExpandVolumeResponse, error) { + ret := m.ctrl.Call(m, "ControllerExpandVolume", arg0, arg1) + ret0, _ := ret[0].(*csi.ControllerExpandVolumeResponse) + ret1, _ := ret[1].(error) + return ret0, ret1 +} + +// ControllerExpandVolume indicates an expected call of ControllerExpandVolume +func (mr *MockControllerServerMockRecorder) ControllerExpandVolume(arg0, arg1 interface{}) *gomock.Call { + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "ControllerExpandVolume", reflect.TypeOf((*MockControllerServer)(nil).ControllerExpandVolume), arg0, arg1) +} + +// ControllerGetCapabilities mocks base method +func (m *MockControllerServer) ControllerGetCapabilities(arg0 context.Context, arg1 *csi.ControllerGetCapabilitiesRequest) (*csi.ControllerGetCapabilitiesResponse, error) { + ret := m.ctrl.Call(m, "ControllerGetCapabilities", arg0, arg1) + ret0, _ := ret[0].(*csi.ControllerGetCapabilitiesResponse) + ret1, _ := ret[1].(error) + return ret0, ret1 +} + +// ControllerGetCapabilities indicates an expected call of ControllerGetCapabilities +func (mr *MockControllerServerMockRecorder) ControllerGetCapabilities(arg0, arg1 interface{}) *gomock.Call { + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "ControllerGetCapabilities", reflect.TypeOf((*MockControllerServer)(nil).ControllerGetCapabilities), arg0, arg1) +} + +// ControllerPublishVolume mocks base method +func (m *MockControllerServer) ControllerPublishVolume(arg0 context.Context, arg1 *csi.ControllerPublishVolumeRequest) (*csi.ControllerPublishVolumeResponse, error) { + ret := m.ctrl.Call(m, "ControllerPublishVolume", arg0, arg1) + ret0, _ := ret[0].(*csi.ControllerPublishVolumeResponse) + ret1, _ := ret[1].(error) + return ret0, ret1 +} + +// ControllerPublishVolume indicates an expected call of ControllerPublishVolume +func (mr *MockControllerServerMockRecorder) ControllerPublishVolume(arg0, arg1 interface{}) *gomock.Call { + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "ControllerPublishVolume", reflect.TypeOf((*MockControllerServer)(nil).ControllerPublishVolume), arg0, arg1) +} + +// ControllerUnpublishVolume mocks base method +func (m *MockControllerServer) ControllerUnpublishVolume(arg0 context.Context, arg1 *csi.ControllerUnpublishVolumeRequest) (*csi.ControllerUnpublishVolumeResponse, error) { + ret := m.ctrl.Call(m, "ControllerUnpublishVolume", arg0, arg1) + ret0, _ := ret[0].(*csi.ControllerUnpublishVolumeResponse) + ret1, _ := ret[1].(error) + return ret0, ret1 +} + +// ControllerUnpublishVolume indicates an expected call of ControllerUnpublishVolume +func (mr *MockControllerServerMockRecorder) ControllerUnpublishVolume(arg0, arg1 interface{}) *gomock.Call { + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "ControllerUnpublishVolume", reflect.TypeOf((*MockControllerServer)(nil).ControllerUnpublishVolume), arg0, arg1) +} + +// CreateSnapshot mocks base method +func (m *MockControllerServer) CreateSnapshot(arg0 context.Context, arg1 *csi.CreateSnapshotRequest) (*csi.CreateSnapshotResponse, error) { + ret := m.ctrl.Call(m, "CreateSnapshot", arg0, arg1) + ret0, _ := ret[0].(*csi.CreateSnapshotResponse) + ret1, _ := ret[1].(error) + return ret0, ret1 +} + +// CreateSnapshot indicates an expected call of CreateSnapshot +func (mr *MockControllerServerMockRecorder) CreateSnapshot(arg0, arg1 interface{}) *gomock.Call { + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "CreateSnapshot", reflect.TypeOf((*MockControllerServer)(nil).CreateSnapshot), arg0, arg1) +} + +// CreateVolume mocks base method +func (m *MockControllerServer) CreateVolume(arg0 context.Context, arg1 *csi.CreateVolumeRequest) (*csi.CreateVolumeResponse, error) { + ret := m.ctrl.Call(m, "CreateVolume", arg0, arg1) + ret0, _ := ret[0].(*csi.CreateVolumeResponse) + ret1, _ := ret[1].(error) + return ret0, ret1 +} + +// CreateVolume indicates an expected call of CreateVolume +func (mr *MockControllerServerMockRecorder) CreateVolume(arg0, arg1 interface{}) *gomock.Call { + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "CreateVolume", reflect.TypeOf((*MockControllerServer)(nil).CreateVolume), arg0, arg1) +} + +// DeleteSnapshot mocks base method +func (m *MockControllerServer) DeleteSnapshot(arg0 context.Context, arg1 *csi.DeleteSnapshotRequest) (*csi.DeleteSnapshotResponse, error) { + ret := m.ctrl.Call(m, "DeleteSnapshot", arg0, arg1) + ret0, _ := ret[0].(*csi.DeleteSnapshotResponse) + ret1, _ := ret[1].(error) + return ret0, ret1 +} + +// DeleteSnapshot indicates an expected call of DeleteSnapshot +func (mr *MockControllerServerMockRecorder) DeleteSnapshot(arg0, arg1 interface{}) *gomock.Call { + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "DeleteSnapshot", reflect.TypeOf((*MockControllerServer)(nil).DeleteSnapshot), arg0, arg1) +} + +// DeleteVolume mocks base method +func (m *MockControllerServer) DeleteVolume(arg0 context.Context, arg1 *csi.DeleteVolumeRequest) (*csi.DeleteVolumeResponse, error) { + ret := m.ctrl.Call(m, "DeleteVolume", arg0, arg1) + ret0, _ := ret[0].(*csi.DeleteVolumeResponse) + ret1, _ := ret[1].(error) + return ret0, ret1 +} + +// DeleteVolume indicates an expected call of DeleteVolume +func (mr *MockControllerServerMockRecorder) DeleteVolume(arg0, arg1 interface{}) *gomock.Call { + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "DeleteVolume", reflect.TypeOf((*MockControllerServer)(nil).DeleteVolume), arg0, arg1) +} + +// GetCapacity mocks base method +func (m *MockControllerServer) GetCapacity(arg0 context.Context, arg1 *csi.GetCapacityRequest) (*csi.GetCapacityResponse, error) { + ret := m.ctrl.Call(m, "GetCapacity", arg0, arg1) + ret0, _ := ret[0].(*csi.GetCapacityResponse) + ret1, _ := ret[1].(error) + return ret0, ret1 +} + +// GetCapacity indicates an expected call of GetCapacity +func (mr *MockControllerServerMockRecorder) GetCapacity(arg0, arg1 interface{}) *gomock.Call { + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetCapacity", reflect.TypeOf((*MockControllerServer)(nil).GetCapacity), arg0, arg1) +} + +// ListSnapshots mocks base method +func (m *MockControllerServer) ListSnapshots(arg0 context.Context, arg1 *csi.ListSnapshotsRequest) (*csi.ListSnapshotsResponse, error) { + ret := m.ctrl.Call(m, "ListSnapshots", arg0, arg1) + ret0, _ := ret[0].(*csi.ListSnapshotsResponse) + ret1, _ := ret[1].(error) + return ret0, ret1 +} + +// ListSnapshots indicates an expected call of ListSnapshots +func (mr *MockControllerServerMockRecorder) ListSnapshots(arg0, arg1 interface{}) *gomock.Call { + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "ListSnapshots", reflect.TypeOf((*MockControllerServer)(nil).ListSnapshots), arg0, arg1) +} + +// ListVolumes mocks base method +func (m *MockControllerServer) ListVolumes(arg0 context.Context, arg1 *csi.ListVolumesRequest) (*csi.ListVolumesResponse, error) { + ret := m.ctrl.Call(m, "ListVolumes", arg0, arg1) + ret0, _ := ret[0].(*csi.ListVolumesResponse) + ret1, _ := ret[1].(error) + return ret0, ret1 +} + +func (m *MockControllerServer) ControllerGetVolume(arg0 context.Context, arg1 *csi.ControllerGetVolumeRequest) (*csi.ControllerGetVolumeResponse, error) { + ret := m.ctrl.Call(m, "ControllerGetVolume", arg0, arg1) + ret0, _ := ret[0].(*csi.ControllerGetVolumeResponse) + ret1, _ := ret[1].(error) + return ret0, ret1 +} + +// ControllerGetVolume indicates an expected call of ControllerGetVolume +func (mr *MockControllerServerMockRecorder) ControllerGetVolume(arg0, arg1 interface{}) *gomock.Call { + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "ControllerGetVolume", reflect.TypeOf((*MockControllerServer)(nil).ControllerGetVolume), arg0, arg1) +} + +// ListVolumes indicates an expected call of ListVolumes +func (mr *MockControllerServerMockRecorder) ListVolumes(arg0, arg1 interface{}) *gomock.Call { + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "ListVolumes", reflect.TypeOf((*MockControllerServer)(nil).ListVolumes), arg0, arg1) +} + +// ValidateVolumeCapabilities mocks base method +func (m *MockControllerServer) ValidateVolumeCapabilities(arg0 context.Context, arg1 *csi.ValidateVolumeCapabilitiesRequest) (*csi.ValidateVolumeCapabilitiesResponse, error) { + ret := m.ctrl.Call(m, "ValidateVolumeCapabilities", arg0, arg1) + ret0, _ := ret[0].(*csi.ValidateVolumeCapabilitiesResponse) + ret1, _ := ret[1].(error) + return ret0, ret1 +} + +// ValidateVolumeCapabilities indicates an expected call of ValidateVolumeCapabilities +func (mr *MockControllerServerMockRecorder) ValidateVolumeCapabilities(arg0, arg1 interface{}) *gomock.Call { + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "ValidateVolumeCapabilities", reflect.TypeOf((*MockControllerServer)(nil).ValidateVolumeCapabilities), arg0, arg1) +} + +// MockNodeServer is a mock of NodeServer interface +type MockNodeServer struct { + ctrl *gomock.Controller + recorder *MockNodeServerMockRecorder +} + +// MockNodeServerMockRecorder is the mock recorder for MockNodeServer +type MockNodeServerMockRecorder struct { + mock *MockNodeServer +} + +// NewMockNodeServer creates a new mock instance +func NewMockNodeServer(ctrl *gomock.Controller) *MockNodeServer { + mock := &MockNodeServer{ctrl: ctrl} + mock.recorder = &MockNodeServerMockRecorder{mock} + return mock +} + +// EXPECT returns an object that allows the caller to indicate expected use +func (m *MockNodeServer) EXPECT() *MockNodeServerMockRecorder { + return m.recorder +} + +// NodeExpandVolume mocks base method +func (m *MockNodeServer) NodeExpandVolume(arg0 context.Context, arg1 *csi.NodeExpandVolumeRequest) (*csi.NodeExpandVolumeResponse, error) { + ret := m.ctrl.Call(m, "NodeExpandVolume", arg0, arg1) + ret0, _ := ret[0].(*csi.NodeExpandVolumeResponse) + ret1, _ := ret[1].(error) + return ret0, ret1 +} + +// NodeExpandVolume indicates an expected call of NodeExpandVolume +func (mr *MockNodeServerMockRecorder) NodeExpandVolume(arg0, arg1 interface{}) *gomock.Call { + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "NodeExpandVolume", reflect.TypeOf((*MockNodeServer)(nil).NodeExpandVolume), arg0, arg1) +} + +// NodeGetCapabilities mocks base method +func (m *MockNodeServer) NodeGetCapabilities(arg0 context.Context, arg1 *csi.NodeGetCapabilitiesRequest) (*csi.NodeGetCapabilitiesResponse, error) { + ret := m.ctrl.Call(m, "NodeGetCapabilities", arg0, arg1) + ret0, _ := ret[0].(*csi.NodeGetCapabilitiesResponse) + ret1, _ := ret[1].(error) + return ret0, ret1 +} + +// NodeGetCapabilities indicates an expected call of NodeGetCapabilities +func (mr *MockNodeServerMockRecorder) NodeGetCapabilities(arg0, arg1 interface{}) *gomock.Call { + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "NodeGetCapabilities", reflect.TypeOf((*MockNodeServer)(nil).NodeGetCapabilities), arg0, arg1) +} + +// NodeGetInfo mocks base method +func (m *MockNodeServer) NodeGetInfo(arg0 context.Context, arg1 *csi.NodeGetInfoRequest) (*csi.NodeGetInfoResponse, error) { + ret := m.ctrl.Call(m, "NodeGetInfo", arg0, arg1) + ret0, _ := ret[0].(*csi.NodeGetInfoResponse) + ret1, _ := ret[1].(error) + return ret0, ret1 +} + +// NodeGetInfo indicates an expected call of NodeGetInfo +func (mr *MockNodeServerMockRecorder) NodeGetInfo(arg0, arg1 interface{}) *gomock.Call { + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "NodeGetInfo", reflect.TypeOf((*MockNodeServer)(nil).NodeGetInfo), arg0, arg1) +} + +// NodeGetVolumeStats mocks base method +func (m *MockNodeServer) NodeGetVolumeStats(arg0 context.Context, arg1 *csi.NodeGetVolumeStatsRequest) (*csi.NodeGetVolumeStatsResponse, error) { + ret := m.ctrl.Call(m, "NodeGetVolumeStats", arg0, arg1) + ret0, _ := ret[0].(*csi.NodeGetVolumeStatsResponse) + ret1, _ := ret[1].(error) + return ret0, ret1 +} + +// NodeGetVolumeStats indicates an expected call of NodeGetVolumeStats +func (mr *MockNodeServerMockRecorder) NodeGetVolumeStats(arg0, arg1 interface{}) *gomock.Call { + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "NodeGetVolumeStats", reflect.TypeOf((*MockNodeServer)(nil).NodeGetVolumeStats), arg0, arg1) +} + +// NodePublishVolume mocks base method +func (m *MockNodeServer) NodePublishVolume(arg0 context.Context, arg1 *csi.NodePublishVolumeRequest) (*csi.NodePublishVolumeResponse, error) { + ret := m.ctrl.Call(m, "NodePublishVolume", arg0, arg1) + ret0, _ := ret[0].(*csi.NodePublishVolumeResponse) + ret1, _ := ret[1].(error) + return ret0, ret1 +} + +// NodePublishVolume indicates an expected call of NodePublishVolume +func (mr *MockNodeServerMockRecorder) NodePublishVolume(arg0, arg1 interface{}) *gomock.Call { + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "NodePublishVolume", reflect.TypeOf((*MockNodeServer)(nil).NodePublishVolume), arg0, arg1) +} + +// NodeStageVolume mocks base method +func (m *MockNodeServer) NodeStageVolume(arg0 context.Context, arg1 *csi.NodeStageVolumeRequest) (*csi.NodeStageVolumeResponse, error) { + ret := m.ctrl.Call(m, "NodeStageVolume", arg0, arg1) + ret0, _ := ret[0].(*csi.NodeStageVolumeResponse) + ret1, _ := ret[1].(error) + return ret0, ret1 +} + +// NodeStageVolume indicates an expected call of NodeStageVolume +func (mr *MockNodeServerMockRecorder) NodeStageVolume(arg0, arg1 interface{}) *gomock.Call { + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "NodeStageVolume", reflect.TypeOf((*MockNodeServer)(nil).NodeStageVolume), arg0, arg1) +} + +// NodeUnpublishVolume mocks base method +func (m *MockNodeServer) NodeUnpublishVolume(arg0 context.Context, arg1 *csi.NodeUnpublishVolumeRequest) (*csi.NodeUnpublishVolumeResponse, error) { + ret := m.ctrl.Call(m, "NodeUnpublishVolume", arg0, arg1) + ret0, _ := ret[0].(*csi.NodeUnpublishVolumeResponse) + ret1, _ := ret[1].(error) + return ret0, ret1 +} + +// NodeUnpublishVolume indicates an expected call of NodeUnpublishVolume +func (mr *MockNodeServerMockRecorder) NodeUnpublishVolume(arg0, arg1 interface{}) *gomock.Call { + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "NodeUnpublishVolume", reflect.TypeOf((*MockNodeServer)(nil).NodeUnpublishVolume), arg0, arg1) +} + +// NodeUnstageVolume mocks base method +func (m *MockNodeServer) NodeUnstageVolume(arg0 context.Context, arg1 *csi.NodeUnstageVolumeRequest) (*csi.NodeUnstageVolumeResponse, error) { + ret := m.ctrl.Call(m, "NodeUnstageVolume", arg0, arg1) + ret0, _ := ret[0].(*csi.NodeUnstageVolumeResponse) + ret1, _ := ret[1].(error) + return ret0, ret1 +} + +// NodeUnstageVolume indicates an expected call of NodeUnstageVolume +func (mr *MockNodeServerMockRecorder) NodeUnstageVolume(arg0, arg1 interface{}) *gomock.Call { + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "NodeUnstageVolume", reflect.TypeOf((*MockNodeServer)(nil).NodeUnstageVolume), arg0, arg1) +} diff --git a/test/e2e/storage/drivers/csi-test/driver/mock.go b/test/e2e/storage/drivers/csi-test/driver/mock.go new file mode 100644 index 00000000000..7e2b5020104 --- /dev/null +++ b/test/e2e/storage/drivers/csi-test/driver/mock.go @@ -0,0 +1,89 @@ +/* +Copyright 2017 Luis Pabón luis@portworx.com + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +*/ + +package driver + +import ( + "net" + + "github.com/kubernetes-csi/csi-test/v4/utils" + "google.golang.org/grpc" +) + +type MockCSIDriverServers struct { + Controller *MockControllerServer + Identity *MockIdentityServer + Node *MockNodeServer +} + +type MockCSIDriver struct { + CSIDriver + conn *grpc.ClientConn +} + +func NewMockCSIDriver(servers *MockCSIDriverServers) *MockCSIDriver { + return &MockCSIDriver{ + CSIDriver: CSIDriver{ + servers: &CSIDriverServers{ + Controller: servers.Controller, + Node: servers.Node, + Identity: servers.Identity, + }, + }, + } +} + +// StartOnAddress starts a new gRPC server listening on given address. +func (m *MockCSIDriver) StartOnAddress(network, address string) error { + l, err := net.Listen(network, address) + if err != nil { + return err + } + + if err := m.CSIDriver.Start(l); err != nil { + l.Close() + return err + } + + return nil +} + +// Start starts a new gRPC server listening on a random TCP loopback port. +func (m *MockCSIDriver) Start() error { + // Listen on a port assigned by the net package + return m.StartOnAddress("tcp", "127.0.0.1:0") +} + +func (m *MockCSIDriver) Nexus() (*grpc.ClientConn, error) { + // Start server + err := m.Start() + if err != nil { + return nil, err + } + + // Create a client connection + m.conn, err = utils.Connect(m.Address(), grpc.WithInsecure()) + if err != nil { + return nil, err + } + + return m.conn, nil +} + +func (m *MockCSIDriver) Close() { + m.conn.Close() + m.server.Stop() +} diff --git a/test/e2e/storage/drivers/csi-test/mock/cache/SnapshotCache.go b/test/e2e/storage/drivers/csi-test/mock/cache/SnapshotCache.go new file mode 100644 index 00000000000..89835e11f20 --- /dev/null +++ b/test/e2e/storage/drivers/csi-test/mock/cache/SnapshotCache.go @@ -0,0 +1,89 @@ +package cache + +import ( + "strings" + "sync" + + "github.com/container-storage-interface/spec/lib/go/csi" +) + +type SnapshotCache interface { + Add(snapshot Snapshot) + + Delete(i int) + + List(ready bool) []csi.Snapshot + + FindSnapshot(k, v string) (int, Snapshot) +} + +type Snapshot struct { + Name string + Parameters map[string]string + SnapshotCSI csi.Snapshot +} + +type snapshotCache struct { + snapshotsRWL sync.RWMutex + snapshots []Snapshot +} + +func NewSnapshotCache() SnapshotCache { + return &snapshotCache{ + snapshots: make([]Snapshot, 0), + } +} + +func (snap *snapshotCache) Add(snapshot Snapshot) { + snap.snapshotsRWL.Lock() + defer snap.snapshotsRWL.Unlock() + + snap.snapshots = append(snap.snapshots, snapshot) +} + +func (snap *snapshotCache) Delete(i int) { + snap.snapshotsRWL.Lock() + defer snap.snapshotsRWL.Unlock() + + copy(snap.snapshots[i:], snap.snapshots[i+1:]) + snap.snapshots = snap.snapshots[:len(snap.snapshots)-1] +} + +func (snap *snapshotCache) List(ready bool) []csi.Snapshot { + snap.snapshotsRWL.RLock() + defer snap.snapshotsRWL.RUnlock() + + snapshots := make([]csi.Snapshot, 0) + for _, v := range snap.snapshots { + if v.SnapshotCSI.GetReadyToUse() { + snapshots = append(snapshots, v.SnapshotCSI) + } + } + + return snapshots +} + +func (snap *snapshotCache) FindSnapshot(k, v string) (int, Snapshot) { + snap.snapshotsRWL.RLock() + defer snap.snapshotsRWL.RUnlock() + + snapshotIdx := -1 + for i, vi := range snap.snapshots { + switch k { + case "id": + if strings.EqualFold(v, vi.SnapshotCSI.GetSnapshotId()) { + return i, vi + } + case "sourceVolumeId": + if strings.EqualFold(v, vi.SnapshotCSI.SourceVolumeId) { + return i, vi + } + case "name": + if vi.Name == v { + return i, vi + } + } + } + + return snapshotIdx, Snapshot{} +} diff --git a/test/e2e/storage/drivers/csi-test/mock/service/controller.go b/test/e2e/storage/drivers/csi-test/mock/service/controller.go new file mode 100644 index 00000000000..a8192fedc0e --- /dev/null +++ b/test/e2e/storage/drivers/csi-test/mock/service/controller.go @@ -0,0 +1,834 @@ +package service + +import ( + "fmt" + "math" + "path" + "reflect" + "strconv" + + "github.com/container-storage-interface/spec/lib/go/csi" + log "github.com/sirupsen/logrus" + "golang.org/x/net/context" + "google.golang.org/grpc/codes" + "google.golang.org/grpc/status" +) + +const ( + MaxStorageCapacity = tib + ReadOnlyKey = "readonly" +) + +func (s *service) CreateVolume( + ctx context.Context, + req *csi.CreateVolumeRequest) ( + *csi.CreateVolumeResponse, error) { + + if len(req.Name) == 0 { + return nil, status.Error(codes.InvalidArgument, "Volume Name cannot be empty") + } + if req.VolumeCapabilities == nil { + return nil, status.Error(codes.InvalidArgument, "Volume Capabilities cannot be empty") + } + if hookVal, hookMsg := s.execHook("CreateVolumeStart"); hookVal != codes.OK { + return nil, status.Errorf(hookVal, hookMsg) + } + // Check to see if the volume already exists. + if i, v := s.findVolByName(ctx, req.Name); i >= 0 { + // Requested volume name already exists, need to check if the existing volume's + // capacity is more or equal to new request's capacity. + if v.GetCapacityBytes() < req.GetCapacityRange().GetRequiredBytes() { + return nil, status.Error(codes.AlreadyExists, + fmt.Sprintf("Volume with name %s already exists", req.GetName())) + } + return &csi.CreateVolumeResponse{Volume: &v}, nil + } + + // If no capacity is specified then use 100GiB + capacity := gib100 + if cr := req.CapacityRange; cr != nil { + if rb := cr.RequiredBytes; rb > 0 { + capacity = rb + } + if lb := cr.LimitBytes; lb > 0 { + capacity = lb + } + } + // Check for maximum available capacity + if capacity >= MaxStorageCapacity { + return nil, status.Errorf(codes.OutOfRange, "Requested capacity %d exceeds maximum allowed %d", capacity, MaxStorageCapacity) + } + + var v csi.Volume + // Create volume from content source if provided. + if req.GetVolumeContentSource() != nil { + switch req.GetVolumeContentSource().GetType().(type) { + case *csi.VolumeContentSource_Snapshot: + sid := req.GetVolumeContentSource().GetSnapshot().GetSnapshotId() + // Check if the source snapshot exists. + if snapID, _ := s.snapshots.FindSnapshot("id", sid); snapID >= 0 { + v = s.newVolumeFromSnapshot(req.Name, capacity, snapID) + } else { + return nil, status.Errorf(codes.NotFound, "Requested source snapshot %s not found", sid) + } + case *csi.VolumeContentSource_Volume: + vid := req.GetVolumeContentSource().GetVolume().GetVolumeId() + // Check if the source volume exists. + if volID, _ := s.findVolNoLock("id", vid); volID >= 0 { + v = s.newVolumeFromVolume(req.Name, capacity, volID) + } else { + return nil, status.Errorf(codes.NotFound, "Requested source volume %s not found", vid) + } + } + } else { + v = s.newVolume(req.Name, capacity) + } + + // Add the created volume to the service's in-mem volume slice. + s.volsRWL.Lock() + defer s.volsRWL.Unlock() + s.vols = append(s.vols, v) + MockVolumes[v.GetVolumeId()] = Volume{ + VolumeCSI: v, + NodeID: "", + ISStaged: false, + ISPublished: false, + StageTargetPath: "", + TargetPath: "", + } + + if hookVal, hookMsg := s.execHook("CreateVolumeEnd"); hookVal != codes.OK { + return nil, status.Errorf(hookVal, hookMsg) + } + + return &csi.CreateVolumeResponse{Volume: &v}, nil +} + +func (s *service) DeleteVolume( + ctx context.Context, + req *csi.DeleteVolumeRequest) ( + *csi.DeleteVolumeResponse, error) { + + s.volsRWL.Lock() + defer s.volsRWL.Unlock() + + // If the volume is not specified, return error + if len(req.VolumeId) == 0 { + return nil, status.Error(codes.InvalidArgument, "Volume ID cannot be empty") + } + + if hookVal, hookMsg := s.execHook("DeleteVolumeStart"); hookVal != codes.OK { + return nil, status.Errorf(hookVal, hookMsg) + } + + // If the volume does not exist then return an idempotent response. + i, _ := s.findVolNoLock("id", req.VolumeId) + if i < 0 { + return &csi.DeleteVolumeResponse{}, nil + } + + // This delete logic preserves order and prevents potential memory + // leaks. The slice's elements may not be pointers, but the structs + // themselves have fields that are. + copy(s.vols[i:], s.vols[i+1:]) + s.vols[len(s.vols)-1] = csi.Volume{} + s.vols = s.vols[:len(s.vols)-1] + log.WithField("volumeID", req.VolumeId).Debug("mock delete volume") + + if hookVal, hookMsg := s.execHook("DeleteVolumeEnd"); hookVal != codes.OK { + return nil, status.Errorf(hookVal, hookMsg) + } + return &csi.DeleteVolumeResponse{}, nil +} + +func (s *service) ControllerPublishVolume( + ctx context.Context, + req *csi.ControllerPublishVolumeRequest) ( + *csi.ControllerPublishVolumeResponse, error) { + + if s.config.DisableAttach { + return nil, status.Error(codes.Unimplemented, "ControllerPublish is not supported") + } + + if len(req.VolumeId) == 0 { + return nil, status.Error(codes.InvalidArgument, "Volume ID cannot be empty") + } + if len(req.NodeId) == 0 { + return nil, status.Error(codes.InvalidArgument, "Node ID cannot be empty") + } + if req.VolumeCapability == nil { + return nil, status.Error(codes.InvalidArgument, "Volume Capabilities cannot be empty") + } + + if req.NodeId != s.nodeID { + return nil, status.Errorf(codes.NotFound, "Not matching Node ID %s to Mock Node ID %s", req.NodeId, s.nodeID) + } + + if hookVal, hookMsg := s.execHook("ControllerPublishVolumeStart"); hookVal != codes.OK { + return nil, status.Errorf(hookVal, hookMsg) + } + + s.volsRWL.Lock() + defer s.volsRWL.Unlock() + + i, v := s.findVolNoLock("id", req.VolumeId) + if i < 0 { + return nil, status.Error(codes.NotFound, req.VolumeId) + } + + // devPathKey is the key in the volume's attributes that is set to a + // mock device path if the volume has been published by the controller + // to the specified node. + devPathKey := path.Join(req.NodeId, "dev") + + // Check to see if the volume is already published. + if device := v.VolumeContext[devPathKey]; device != "" { + var volRo bool + var roVal string + if ro, ok := v.VolumeContext[ReadOnlyKey]; ok { + roVal = ro + } + + if roVal == "true" { + volRo = true + } else { + volRo = false + } + + // Check if readonly flag is compatible with the publish request. + if req.GetReadonly() != volRo { + return nil, status.Error(codes.AlreadyExists, "Volume published but has incompatible readonly flag") + } + + return &csi.ControllerPublishVolumeResponse{ + PublishContext: map[string]string{ + "device": device, + "readonly": roVal, + }, + }, nil + } + + // Check attach limit before publishing only if attach limit is set. + if s.config.AttachLimit > 0 && s.getAttachCount(devPathKey) >= s.config.AttachLimit { + return nil, status.Errorf(codes.ResourceExhausted, "Cannot attach any more volumes to this node") + } + + var roVal string + if req.GetReadonly() { + roVal = "true" + } else { + roVal = "false" + } + + // Publish the volume. + device := "/dev/mock" + v.VolumeContext[devPathKey] = device + v.VolumeContext[ReadOnlyKey] = roVal + s.vols[i] = v + + if volInfo, ok := MockVolumes[req.VolumeId]; ok { + volInfo.ISControllerPublished = true + MockVolumes[req.VolumeId] = volInfo + } + + if hookVal, hookMsg := s.execHook("ControllerPublishVolumeEnd"); hookVal != codes.OK { + return nil, status.Errorf(hookVal, hookMsg) + } + + return &csi.ControllerPublishVolumeResponse{ + PublishContext: map[string]string{ + "device": device, + "readonly": roVal, + }, + }, nil +} + +func (s *service) ControllerUnpublishVolume( + ctx context.Context, + req *csi.ControllerUnpublishVolumeRequest) ( + *csi.ControllerUnpublishVolumeResponse, error) { + + if s.config.DisableAttach { + return nil, status.Error(codes.Unimplemented, "ControllerPublish is not supported") + } + + if len(req.VolumeId) == 0 { + return nil, status.Error(codes.InvalidArgument, "Volume ID cannot be empty") + } + nodeID := req.NodeId + if len(nodeID) == 0 { + // If node id is empty, no failure as per Spec + nodeID = s.nodeID + } + + if req.NodeId != s.nodeID { + return nil, status.Errorf(codes.NotFound, "Node ID %s does not match to expected Node ID %s", req.NodeId, s.nodeID) + } + + if hookVal, hookMsg := s.execHook("ControllerUnpublishVolumeStart"); hookVal != codes.OK { + return nil, status.Errorf(hookVal, hookMsg) + } + + s.volsRWL.Lock() + defer s.volsRWL.Unlock() + + i, v := s.findVolNoLock("id", req.VolumeId) + if i < 0 { + // Not an error: a non-existent volume is not published. + // See also https://github.com/kubernetes-csi/external-attacher/pull/165 + return &csi.ControllerUnpublishVolumeResponse{}, nil + } + + // devPathKey is the key in the volume's attributes that is set to a + // mock device path if the volume has been published by the controller + // to the specified node. + devPathKey := path.Join(nodeID, "dev") + + // Check to see if the volume is already unpublished. + if v.VolumeContext[devPathKey] == "" { + return &csi.ControllerUnpublishVolumeResponse{}, nil + } + + // Unpublish the volume. + delete(v.VolumeContext, devPathKey) + delete(v.VolumeContext, ReadOnlyKey) + s.vols[i] = v + + if hookVal, hookMsg := s.execHook("ControllerUnpublishVolumeEnd"); hookVal != codes.OK { + return nil, status.Errorf(hookVal, hookMsg) + } + + return &csi.ControllerUnpublishVolumeResponse{}, nil +} + +func (s *service) ValidateVolumeCapabilities( + ctx context.Context, + req *csi.ValidateVolumeCapabilitiesRequest) ( + *csi.ValidateVolumeCapabilitiesResponse, error) { + + if len(req.GetVolumeId()) == 0 { + return nil, status.Error(codes.InvalidArgument, "Volume ID cannot be empty") + } + if len(req.VolumeCapabilities) == 0 { + return nil, status.Error(codes.InvalidArgument, req.VolumeId) + } + i, _ := s.findVolNoLock("id", req.VolumeId) + if i < 0 { + return nil, status.Error(codes.NotFound, req.VolumeId) + } + + if hookVal, hookMsg := s.execHook("ValidateVolumeCapabilities"); hookVal != codes.OK { + return nil, status.Errorf(hookVal, hookMsg) + } + + return &csi.ValidateVolumeCapabilitiesResponse{ + Confirmed: &csi.ValidateVolumeCapabilitiesResponse_Confirmed{ + VolumeContext: req.GetVolumeContext(), + VolumeCapabilities: req.GetVolumeCapabilities(), + Parameters: req.GetParameters(), + }, + }, nil +} + +func (s *service) ControllerGetVolume( + ctx context.Context, + req *csi.ControllerGetVolumeRequest) ( + *csi.ControllerGetVolumeResponse, error) { + + if hookVal, hookMsg := s.execHook("GetVolumeStart"); hookVal != codes.OK { + return nil, status.Errorf(hookVal, hookMsg) + } + + resp := &csi.ControllerGetVolumeResponse{ + Status: &csi.ControllerGetVolumeResponse_VolumeStatus{ + VolumeCondition: &csi.VolumeCondition{}, + }, + } + i, v := s.findVolByID(ctx, req.VolumeId) + if i < 0 { + resp.Status.VolumeCondition.Abnormal = true + resp.Status.VolumeCondition.Message = "volume not found" + return resp, status.Error(codes.NotFound, req.VolumeId) + } + + resp.Volume = &v + if !s.config.DisableAttach { + resp.Status.PublishedNodeIds = []string{ + s.nodeID, + } + } + + if hookVal, hookMsg := s.execHook("GetVolumeEnd"); hookVal != codes.OK { + return nil, status.Errorf(hookVal, hookMsg) + } + + return resp, nil +} + +func (s *service) ListVolumes( + ctx context.Context, + req *csi.ListVolumesRequest) ( + *csi.ListVolumesResponse, error) { + + if hookVal, hookMsg := s.execHook("ListVolumesStart"); hookVal != codes.OK { + return nil, status.Errorf(hookVal, hookMsg) + } + + // Copy the mock volumes into a new slice in order to avoid + // locking the service's volume slice for the duration of the + // ListVolumes RPC. + var vols []csi.Volume + func() { + s.volsRWL.RLock() + defer s.volsRWL.RUnlock() + vols = make([]csi.Volume, len(s.vols)) + copy(vols, s.vols) + }() + + var ( + ulenVols = int32(len(vols)) + maxEntries = req.MaxEntries + startingToken int32 + ) + + if v := req.StartingToken; v != "" { + i, err := strconv.ParseUint(v, 10, 32) + if err != nil { + return nil, status.Errorf( + codes.Aborted, + "startingToken=%d !< int32=%d", + startingToken, math.MaxUint32) + } + startingToken = int32(i) + } + + if startingToken > ulenVols { + return nil, status.Errorf( + codes.Aborted, + "startingToken=%d > len(vols)=%d", + startingToken, ulenVols) + } + + // Discern the number of remaining entries. + rem := ulenVols - startingToken + + // If maxEntries is 0 or greater than the number of remaining entries then + // set maxEntries to the number of remaining entries. + if maxEntries == 0 || maxEntries > rem { + maxEntries = rem + } + + var ( + i int + j = startingToken + entries = make( + []*csi.ListVolumesResponse_Entry, + maxEntries) + ) + + for i = 0; i < len(entries); i++ { + volumeStatus := &csi.ListVolumesResponse_VolumeStatus{ + VolumeCondition: &csi.VolumeCondition{}, + } + + if !s.config.DisableAttach { + volumeStatus.PublishedNodeIds = []string{ + s.nodeID, + } + } + + entries[i] = &csi.ListVolumesResponse_Entry{ + Volume: &vols[j], + Status: volumeStatus, + } + j++ + } + + var nextToken string + if n := startingToken + int32(i); n < ulenVols { + nextToken = fmt.Sprintf("%d", n) + } + + if hookVal, hookMsg := s.execHook("ListVolumesEnd"); hookVal != codes.OK { + return nil, status.Errorf(hookVal, hookMsg) + } + + return &csi.ListVolumesResponse{ + Entries: entries, + NextToken: nextToken, + }, nil +} + +func (s *service) GetCapacity( + ctx context.Context, + req *csi.GetCapacityRequest) ( + *csi.GetCapacityResponse, error) { + + if hookVal, hookMsg := s.execHook("GetCapacity"); hookVal != codes.OK { + return nil, status.Errorf(hookVal, hookMsg) + } + + return &csi.GetCapacityResponse{ + AvailableCapacity: MaxStorageCapacity, + }, nil +} + +func (s *service) ControllerGetCapabilities( + ctx context.Context, + req *csi.ControllerGetCapabilitiesRequest) ( + *csi.ControllerGetCapabilitiesResponse, error) { + + if hookVal, hookMsg := s.execHook("ControllerGetCapabilitiesStart"); hookVal != codes.OK { + return nil, status.Errorf(hookVal, hookMsg) + } + + caps := []*csi.ControllerServiceCapability{ + { + Type: &csi.ControllerServiceCapability_Rpc{ + Rpc: &csi.ControllerServiceCapability_RPC{ + Type: csi.ControllerServiceCapability_RPC_CREATE_DELETE_VOLUME, + }, + }, + }, + { + Type: &csi.ControllerServiceCapability_Rpc{ + Rpc: &csi.ControllerServiceCapability_RPC{ + Type: csi.ControllerServiceCapability_RPC_LIST_VOLUMES, + }, + }, + }, + { + Type: &csi.ControllerServiceCapability_Rpc{ + Rpc: &csi.ControllerServiceCapability_RPC{ + Type: csi.ControllerServiceCapability_RPC_LIST_VOLUMES_PUBLISHED_NODES, + }, + }, + }, + { + Type: &csi.ControllerServiceCapability_Rpc{ + Rpc: &csi.ControllerServiceCapability_RPC{ + Type: csi.ControllerServiceCapability_RPC_GET_CAPACITY, + }, + }, + }, + { + Type: &csi.ControllerServiceCapability_Rpc{ + Rpc: &csi.ControllerServiceCapability_RPC{ + Type: csi.ControllerServiceCapability_RPC_LIST_SNAPSHOTS, + }, + }, + }, + { + Type: &csi.ControllerServiceCapability_Rpc{ + Rpc: &csi.ControllerServiceCapability_RPC{ + Type: csi.ControllerServiceCapability_RPC_CREATE_DELETE_SNAPSHOT, + }, + }, + }, + { + Type: &csi.ControllerServiceCapability_Rpc{ + Rpc: &csi.ControllerServiceCapability_RPC{ + Type: csi.ControllerServiceCapability_RPC_PUBLISH_READONLY, + }, + }, + }, + { + Type: &csi.ControllerServiceCapability_Rpc{ + Rpc: &csi.ControllerServiceCapability_RPC{ + Type: csi.ControllerServiceCapability_RPC_CLONE_VOLUME, + }, + }, + }, + { + Type: &csi.ControllerServiceCapability_Rpc{ + Rpc: &csi.ControllerServiceCapability_RPC{ + Type: csi.ControllerServiceCapability_RPC_GET_VOLUME, + }, + }, + }, + { + Type: &csi.ControllerServiceCapability_Rpc{ + Rpc: &csi.ControllerServiceCapability_RPC{ + Type: csi.ControllerServiceCapability_RPC_VOLUME_CONDITION, + }, + }, + }, + } + + if !s.config.DisableAttach { + caps = append(caps, &csi.ControllerServiceCapability{ + Type: &csi.ControllerServiceCapability_Rpc{ + Rpc: &csi.ControllerServiceCapability_RPC{ + Type: csi.ControllerServiceCapability_RPC_PUBLISH_UNPUBLISH_VOLUME, + }, + }, + }) + } + + if !s.config.DisableControllerExpansion { + caps = append(caps, &csi.ControllerServiceCapability{ + Type: &csi.ControllerServiceCapability_Rpc{ + Rpc: &csi.ControllerServiceCapability_RPC{ + Type: csi.ControllerServiceCapability_RPC_EXPAND_VOLUME, + }, + }, + }) + } + + if hookVal, hookMsg := s.execHook("ControllerGetCapabilitiesEnd"); hookVal != codes.OK { + return nil, status.Errorf(hookVal, hookMsg) + } + + return &csi.ControllerGetCapabilitiesResponse{ + Capabilities: caps, + }, nil +} + +func (s *service) CreateSnapshot(ctx context.Context, + req *csi.CreateSnapshotRequest) (*csi.CreateSnapshotResponse, error) { + // Check arguments + if len(req.GetName()) == 0 { + return nil, status.Error(codes.InvalidArgument, "Snapshot Name cannot be empty") + } + if len(req.GetSourceVolumeId()) == 0 { + return nil, status.Error(codes.InvalidArgument, "Snapshot SourceVolumeId cannot be empty") + } + + if hookVal, hookMsg := s.execHook("CreateSnapshotStart"); hookVal != codes.OK { + return nil, status.Errorf(hookVal, hookMsg) + } + + // Check to see if the snapshot already exists. + if i, v := s.snapshots.FindSnapshot("name", req.GetName()); i >= 0 { + // Requested snapshot name already exists + if v.SnapshotCSI.GetSourceVolumeId() != req.GetSourceVolumeId() || !reflect.DeepEqual(v.Parameters, req.GetParameters()) { + return nil, status.Error(codes.AlreadyExists, + fmt.Sprintf("Snapshot with name %s already exists", req.GetName())) + } + return &csi.CreateSnapshotResponse{Snapshot: &v.SnapshotCSI}, nil + } + + // Create the snapshot and add it to the service's in-mem snapshot slice. + snapshot := s.newSnapshot(req.GetName(), req.GetSourceVolumeId(), req.GetParameters()) + s.snapshots.Add(snapshot) + + if hookVal, hookMsg := s.execHook("CreateSnapshotEnd"); hookVal != codes.OK { + return nil, status.Errorf(hookVal, hookMsg) + } + + return &csi.CreateSnapshotResponse{Snapshot: &snapshot.SnapshotCSI}, nil +} + +func (s *service) DeleteSnapshot(ctx context.Context, + req *csi.DeleteSnapshotRequest) (*csi.DeleteSnapshotResponse, error) { + + // If the snapshot is not specified, return error + if len(req.SnapshotId) == 0 { + return nil, status.Error(codes.InvalidArgument, "Snapshot ID cannot be empty") + } + + if hookVal, hookMsg := s.execHook("DeleteSnapshotStart"); hookVal != codes.OK { + return nil, status.Errorf(hookVal, hookMsg) + } + + // If the snapshot does not exist then return an idempotent response. + i, _ := s.snapshots.FindSnapshot("id", req.SnapshotId) + if i < 0 { + return &csi.DeleteSnapshotResponse{}, nil + } + + // This delete logic preserves order and prevents potential memory + // leaks. The slice's elements may not be pointers, but the structs + // themselves have fields that are. + s.snapshots.Delete(i) + log.WithField("SnapshotId", req.SnapshotId).Debug("mock delete snapshot") + + if hookVal, hookMsg := s.execHook("DeleteSnapshotEnd"); hookVal != codes.OK { + return nil, status.Errorf(hookVal, hookMsg) + } + + return &csi.DeleteSnapshotResponse{}, nil +} + +func (s *service) ListSnapshots(ctx context.Context, + req *csi.ListSnapshotsRequest) (*csi.ListSnapshotsResponse, error) { + + if hookVal, hookMsg := s.execHook("ListSnapshots"); hookVal != codes.OK { + return nil, status.Errorf(hookVal, hookMsg) + } + + // case 1: SnapshotId is not empty, return snapshots that match the snapshot id. + if len(req.GetSnapshotId()) != 0 { + return getSnapshotById(s, req) + } + + // case 2: SourceVolumeId is not empty, return snapshots that match the source volume id. + if len(req.GetSourceVolumeId()) != 0 { + return getSnapshotByVolumeId(s, req) + } + + // case 3: no parameter is set, so we return all the snapshots. + return getAllSnapshots(s, req) +} + +func (s *service) ControllerExpandVolume( + ctx context.Context, + req *csi.ControllerExpandVolumeRequest) (*csi.ControllerExpandVolumeResponse, error) { + if len(req.VolumeId) == 0 { + return nil, status.Error(codes.InvalidArgument, "Volume ID cannot be empty") + } + + if req.CapacityRange == nil { + return nil, status.Error(codes.InvalidArgument, "Request capacity cannot be empty") + } + + if hookVal, hookMsg := s.execHook("ControllerExpandVolumeStart"); hookVal != codes.OK { + return nil, status.Errorf(hookVal, hookMsg) + } + + s.volsRWL.Lock() + defer s.volsRWL.Unlock() + + i, v := s.findVolNoLock("id", req.VolumeId) + if i < 0 { + return nil, status.Error(codes.NotFound, req.VolumeId) + } + + if s.config.DisableOnlineExpansion && MockVolumes[v.GetVolumeId()].ISControllerPublished { + return nil, status.Error(codes.FailedPrecondition, "volume is published and online volume expansion is not supported") + } + + requestBytes := req.CapacityRange.RequiredBytes + + if v.CapacityBytes > requestBytes { + return nil, status.Error(codes.InvalidArgument, "cannot change volume capacity to a smaller size") + } + + resp := &csi.ControllerExpandVolumeResponse{ + CapacityBytes: requestBytes, + NodeExpansionRequired: s.config.NodeExpansionRequired, + } + + // Check to see if the volume already satisfied request size. + if v.CapacityBytes == requestBytes { + log.WithField("volumeID", v.VolumeId).Infof("Volume capacity is already %d, no need to expand", requestBytes) + return resp, nil + } + + // Update volume's capacity to the requested size. + v.CapacityBytes = requestBytes + s.vols[i] = v + + if hookVal, hookMsg := s.execHook("ControllerExpandVolumeEnd"); hookVal != codes.OK { + return nil, status.Errorf(hookVal, hookMsg) + } + + return resp, nil +} + +func getSnapshotById(s *service, req *csi.ListSnapshotsRequest) (*csi.ListSnapshotsResponse, error) { + if len(req.GetSnapshotId()) != 0 { + i, snapshot := s.snapshots.FindSnapshot("id", req.GetSnapshotId()) + if i < 0 { + return &csi.ListSnapshotsResponse{}, nil + } + + if len(req.GetSourceVolumeId()) != 0 { + if snapshot.SnapshotCSI.GetSourceVolumeId() != req.GetSourceVolumeId() { + return &csi.ListSnapshotsResponse{}, nil + } + } + + return &csi.ListSnapshotsResponse{ + Entries: []*csi.ListSnapshotsResponse_Entry{ + { + Snapshot: &snapshot.SnapshotCSI, + }, + }, + }, nil + } + return nil, nil +} + +func getSnapshotByVolumeId(s *service, req *csi.ListSnapshotsRequest) (*csi.ListSnapshotsResponse, error) { + if len(req.GetSourceVolumeId()) != 0 { + i, snapshot := s.snapshots.FindSnapshot("sourceVolumeId", req.SourceVolumeId) + if i < 0 { + return &csi.ListSnapshotsResponse{}, nil + } + return &csi.ListSnapshotsResponse{ + Entries: []*csi.ListSnapshotsResponse_Entry{ + { + Snapshot: &snapshot.SnapshotCSI, + }, + }, + }, nil + } + return nil, nil +} + +func getAllSnapshots(s *service, req *csi.ListSnapshotsRequest) (*csi.ListSnapshotsResponse, error) { + // Copy the mock snapshots into a new slice in order to avoid + // locking the service's snapshot slice for the duration of the + // ListSnapshots RPC. + readyToUse := true + snapshots := s.snapshots.List(readyToUse) + + var ( + ulenSnapshots = int32(len(snapshots)) + maxEntries = req.MaxEntries + startingToken int32 + ) + + if v := req.StartingToken; v != "" { + i, err := strconv.ParseUint(v, 10, 32) + if err != nil { + return nil, status.Errorf( + codes.Aborted, + "startingToken=%d !< int32=%d", + startingToken, math.MaxUint32) + } + startingToken = int32(i) + } + + if startingToken > ulenSnapshots { + return nil, status.Errorf( + codes.Aborted, + "startingToken=%d > len(snapshots)=%d", + startingToken, ulenSnapshots) + } + + // Discern the number of remaining entries. + rem := ulenSnapshots - startingToken + + // If maxEntries is 0 or greater than the number of remaining entries then + // set maxEntries to the number of remaining entries. + if maxEntries == 0 || maxEntries > rem { + maxEntries = rem + } + + var ( + i int + j = startingToken + entries = make( + []*csi.ListSnapshotsResponse_Entry, + maxEntries) + ) + + for i = 0; i < len(entries); i++ { + entries[i] = &csi.ListSnapshotsResponse_Entry{ + Snapshot: &snapshots[j], + } + j++ + } + + var nextToken string + if n := startingToken + int32(i); n < ulenSnapshots { + nextToken = fmt.Sprintf("%d", n) + } + + return &csi.ListSnapshotsResponse{ + Entries: entries, + NextToken: nextToken, + }, nil +} diff --git a/test/e2e/storage/drivers/csi-test/mock/service/hooks-const.go b/test/e2e/storage/drivers/csi-test/mock/service/hooks-const.go new file mode 100644 index 00000000000..46eed6af7ca --- /dev/null +++ b/test/e2e/storage/drivers/csi-test/mock/service/hooks-const.go @@ -0,0 +1,24 @@ +package service + +// Predefinded constants for the JavaScript hooks, they must correspond to the +// error codes used by gRPC, see: +// https://github.com/grpc/grpc-go/blob/master/codes/codes.go +const ( + grpcJSCodes string = `OK = 0; + CANCELED = 1; + UNKNOWN = 2; + INVALIDARGUMENT = 3; + DEADLINEEXCEEDED = 4; + NOTFOUND = 5; + ALREADYEXISTS = 6; + PERMISSIONDENIED = 7; + RESOURCEEXHAUSTED = 8; + FAILEDPRECONDITION = 9; + ABORTED = 10; + OUTOFRANGE = 11; + UNIMPLEMENTED = 12; + INTERNAL = 13; + UNAVAILABLE = 14; + DATALOSS = 15; + UNAUTHENTICATED = 16` +) diff --git a/test/e2e/storage/drivers/csi-test/mock/service/identity.go b/test/e2e/storage/drivers/csi-test/mock/service/identity.go new file mode 100644 index 00000000000..837c8763c1c --- /dev/null +++ b/test/e2e/storage/drivers/csi-test/mock/service/identity.go @@ -0,0 +1,74 @@ +package service + +import ( + "golang.org/x/net/context" + + "github.com/container-storage-interface/spec/lib/go/csi" + "github.com/golang/protobuf/ptypes/wrappers" +) + +func (s *service) GetPluginInfo( + ctx context.Context, + req *csi.GetPluginInfoRequest) ( + *csi.GetPluginInfoResponse, error) { + + return &csi.GetPluginInfoResponse{ + Name: s.config.DriverName, + VendorVersion: VendorVersion, + Manifest: Manifest, + }, nil +} + +func (s *service) Probe( + ctx context.Context, + req *csi.ProbeRequest) ( + *csi.ProbeResponse, error) { + + return &csi.ProbeResponse{ + Ready: &wrappers.BoolValue{Value: true}, + }, nil +} + +func (s *service) GetPluginCapabilities( + ctx context.Context, + req *csi.GetPluginCapabilitiesRequest) ( + *csi.GetPluginCapabilitiesResponse, error) { + + volExpType := csi.PluginCapability_VolumeExpansion_ONLINE + + if s.config.DisableOnlineExpansion { + volExpType = csi.PluginCapability_VolumeExpansion_OFFLINE + } + + capabilities := []*csi.PluginCapability{ + { + Type: &csi.PluginCapability_Service_{ + Service: &csi.PluginCapability_Service{ + Type: csi.PluginCapability_Service_CONTROLLER_SERVICE, + }, + }, + }, + { + Type: &csi.PluginCapability_VolumeExpansion_{ + VolumeExpansion: &csi.PluginCapability_VolumeExpansion{ + Type: volExpType, + }, + }, + }, + } + + if s.config.EnableTopology { + capabilities = append(capabilities, + &csi.PluginCapability{ + Type: &csi.PluginCapability_Service_{ + Service: &csi.PluginCapability_Service{ + Type: csi.PluginCapability_Service_VOLUME_ACCESSIBILITY_CONSTRAINTS, + }, + }, + }) + } + + return &csi.GetPluginCapabilitiesResponse{ + Capabilities: capabilities, + }, nil +} diff --git a/test/e2e/storage/drivers/csi-test/mock/service/node.go b/test/e2e/storage/drivers/csi-test/mock/service/node.go new file mode 100644 index 00000000000..7c509150181 --- /dev/null +++ b/test/e2e/storage/drivers/csi-test/mock/service/node.go @@ -0,0 +1,460 @@ +package service + +import ( + "fmt" + "os" + "path" + "strconv" + + "google.golang.org/grpc/codes" + "google.golang.org/grpc/status" + + "golang.org/x/net/context" + + "github.com/container-storage-interface/spec/lib/go/csi" +) + +func (s *service) NodeStageVolume( + ctx context.Context, + req *csi.NodeStageVolumeRequest) ( + *csi.NodeStageVolumeResponse, error) { + + if hookVal, hookMsg := s.execHook("NodeStageVolumeStart"); hookVal != codes.OK { + return nil, status.Errorf(hookVal, hookMsg) + } + + device, ok := req.PublishContext["device"] + if !ok { + if s.config.DisableAttach { + device = "mock device" + } else { + return nil, status.Error( + codes.InvalidArgument, + "stage volume info 'device' key required") + } + } + + if len(req.GetVolumeId()) == 0 { + return nil, status.Error(codes.InvalidArgument, "Volume ID cannot be empty") + } + + if len(req.GetStagingTargetPath()) == 0 { + return nil, status.Error(codes.InvalidArgument, "Staging Target Path cannot be empty") + } + + if req.GetVolumeCapability() == nil { + return nil, status.Error(codes.InvalidArgument, "Volume Capability cannot be empty") + } + + exists, err := checkTargetExists(req.StagingTargetPath) + if err != nil { + return nil, status.Error(codes.Internal, err.Error()) + } + if !exists { + status.Errorf(codes.Internal, "staging target path %s does not exist", req.StagingTargetPath) + } + + s.volsRWL.Lock() + defer s.volsRWL.Unlock() + + i, v := s.findVolNoLock("id", req.VolumeId) + if i < 0 { + return nil, status.Error(codes.NotFound, req.VolumeId) + } + + // nodeStgPathKey is the key in the volume's attributes that is set to a + // mock stage path if the volume has been published by the node + nodeStgPathKey := path.Join(s.nodeID, req.StagingTargetPath) + + // Check to see if the volume has already been staged. + if v.VolumeContext[nodeStgPathKey] != "" { + // TODO: Check for the capabilities to be equal. Return "ALREADY_EXISTS" + // if the capabilities don't match. + return &csi.NodeStageVolumeResponse{}, nil + } + + // Stage the volume. + v.VolumeContext[nodeStgPathKey] = device + s.vols[i] = v + + if hookVal, hookMsg := s.execHook("NodeStageVolumeEnd"); hookVal != codes.OK { + return nil, status.Errorf(hookVal, hookMsg) + } + + return &csi.NodeStageVolumeResponse{}, nil +} + +func (s *service) NodeUnstageVolume( + ctx context.Context, + req *csi.NodeUnstageVolumeRequest) ( + *csi.NodeUnstageVolumeResponse, error) { + + if len(req.GetVolumeId()) == 0 { + return nil, status.Error(codes.InvalidArgument, "Volume ID cannot be empty") + } + + if len(req.GetStagingTargetPath()) == 0 { + return nil, status.Error(codes.InvalidArgument, "Staging Target Path cannot be empty") + } + + if hookVal, hookMsg := s.execHook("NodeUnstageVolumeStart"); hookVal != codes.OK { + return nil, status.Errorf(hookVal, hookMsg) + } + + s.volsRWL.Lock() + defer s.volsRWL.Unlock() + + i, v := s.findVolNoLock("id", req.VolumeId) + if i < 0 { + return nil, status.Error(codes.NotFound, req.VolumeId) + } + + // nodeStgPathKey is the key in the volume's attributes that is set to a + // mock stage path if the volume has been published by the node + nodeStgPathKey := path.Join(s.nodeID, req.StagingTargetPath) + + // Check to see if the volume has already been unstaged. + if v.VolumeContext[nodeStgPathKey] == "" { + return &csi.NodeUnstageVolumeResponse{}, nil + } + + // Unpublish the volume. + delete(v.VolumeContext, nodeStgPathKey) + s.vols[i] = v + + if hookVal, hookMsg := s.execHook("NodeUnstageVolumeEnd"); hookVal != codes.OK { + return nil, status.Errorf(hookVal, hookMsg) + } + return &csi.NodeUnstageVolumeResponse{}, nil +} + +func (s *service) NodePublishVolume( + ctx context.Context, + req *csi.NodePublishVolumeRequest) ( + *csi.NodePublishVolumeResponse, error) { + + if hookVal, hookMsg := s.execHook("NodePublishVolumeStart"); hookVal != codes.OK { + return nil, status.Errorf(hookVal, hookMsg) + } + ephemeralVolume := req.GetVolumeContext()["csi.storage.k8s.io/ephemeral"] == "true" + device, ok := req.PublishContext["device"] + if !ok { + if ephemeralVolume || s.config.DisableAttach { + device = "mock device" + } else { + return nil, status.Error( + codes.InvalidArgument, + "stage volume info 'device' key required") + } + } + + if len(req.GetVolumeId()) == 0 { + return nil, status.Error(codes.InvalidArgument, "Volume ID cannot be empty") + } + + if len(req.GetTargetPath()) == 0 { + return nil, status.Error(codes.InvalidArgument, "Target Path cannot be empty") + } + + if req.GetVolumeCapability() == nil { + return nil, status.Error(codes.InvalidArgument, "Volume Capability cannot be empty") + } + + // May happen with old (or, at this time, even the current) Kubernetes + // although it shouldn't (https://github.com/kubernetes/kubernetes/issues/75535). + exists, err := checkTargetExists(req.TargetPath) + if err != nil { + return nil, status.Error(codes.Internal, err.Error()) + } + if !s.config.PermissiveTargetPath && exists { + status.Errorf(codes.Internal, "target path %s does exist", req.TargetPath) + } + + s.volsRWL.Lock() + defer s.volsRWL.Unlock() + + i, v := s.findVolNoLock("id", req.VolumeId) + if i < 0 && !ephemeralVolume { + return nil, status.Error(codes.NotFound, req.VolumeId) + } + if i >= 0 && ephemeralVolume { + return nil, status.Error(codes.AlreadyExists, req.VolumeId) + } + + // nodeMntPathKey is the key in the volume's attributes that is set to a + // mock mount path if the volume has been published by the node + nodeMntPathKey := path.Join(s.nodeID, req.TargetPath) + + // Check to see if the volume has already been published. + if v.VolumeContext[nodeMntPathKey] != "" { + + // Requests marked Readonly fail due to volumes published by + // the Mock driver supporting only RW mode. + if req.Readonly { + return nil, status.Error(codes.AlreadyExists, req.VolumeId) + } + + return &csi.NodePublishVolumeResponse{}, nil + } + + // Publish the volume. + if ephemeralVolume { + MockVolumes[req.VolumeId] = Volume{ + ISEphemeral: true, + } + } else { + if req.GetTargetPath() != "" { + exists, err := checkTargetExists(req.GetTargetPath()) + if err != nil { + return nil, status.Error(codes.Internal, err.Error()) + } + if !exists { + // If target path does not exist we need to create the directory where volume will be staged + if err = os.Mkdir(req.TargetPath, os.FileMode(0755)); err != nil { + msg := fmt.Sprintf("NodePublishVolume: could not create target dir %q: %v", req.TargetPath, err) + return nil, status.Error(codes.Internal, msg) + } + } + v.VolumeContext[nodeMntPathKey] = req.GetTargetPath() + } else { + v.VolumeContext[nodeMntPathKey] = device + } + s.vols[i] = v + } + if hookVal, hookMsg := s.execHook("NodePublishVolumeEnd"); hookVal != codes.OK { + return nil, status.Errorf(hookVal, hookMsg) + } + + return &csi.NodePublishVolumeResponse{}, nil +} + +func (s *service) NodeUnpublishVolume( + ctx context.Context, + req *csi.NodeUnpublishVolumeRequest) ( + *csi.NodeUnpublishVolumeResponse, error) { + + if len(req.GetVolumeId()) == 0 { + return nil, status.Error(codes.InvalidArgument, "Volume ID cannot be empty") + } + if len(req.GetTargetPath()) == 0 { + return nil, status.Error(codes.InvalidArgument, "Target Path cannot be empty") + } + if hookVal, hookMsg := s.execHook("NodeUnpublishVolumeStart"); hookVal != codes.OK { + return nil, status.Errorf(hookVal, hookMsg) + } + + s.volsRWL.Lock() + defer s.volsRWL.Unlock() + + ephemeralVolume := MockVolumes[req.VolumeId].ISEphemeral + i, v := s.findVolNoLock("id", req.VolumeId) + if i < 0 && !ephemeralVolume { + return nil, status.Error(codes.NotFound, req.VolumeId) + } + + if ephemeralVolume { + delete(MockVolumes, req.VolumeId) + } else { + // nodeMntPathKey is the key in the volume's attributes that is set to a + // mock mount path if the volume has been published by the node + nodeMntPathKey := path.Join(s.nodeID, req.TargetPath) + + // Check to see if the volume has already been unpublished. + if v.VolumeContext[nodeMntPathKey] == "" { + return &csi.NodeUnpublishVolumeResponse{}, nil + } + + // Delete any created paths + err := os.RemoveAll(v.VolumeContext[nodeMntPathKey]) + if err != nil { + return nil, status.Errorf(codes.Internal, "Unable to delete previously created target directory") + } + + // Unpublish the volume. + delete(v.VolumeContext, nodeMntPathKey) + s.vols[i] = v + } + if hookVal, hookMsg := s.execHook("NodeUnpublishVolumeEnd"); hookVal != codes.OK { + return nil, status.Errorf(hookVal, hookMsg) + } + + return &csi.NodeUnpublishVolumeResponse{}, nil +} + +func (s *service) NodeExpandVolume(ctx context.Context, req *csi.NodeExpandVolumeRequest) (*csi.NodeExpandVolumeResponse, error) { + if len(req.GetVolumeId()) == 0 { + return nil, status.Error(codes.InvalidArgument, "Volume ID cannot be empty") + } + if len(req.GetVolumePath()) == 0 { + return nil, status.Error(codes.InvalidArgument, "Volume Path cannot be empty") + } + if hookVal, hookMsg := s.execHook("NodeExpandVolumeStart"); hookVal != codes.OK { + return nil, status.Errorf(hookVal, hookMsg) + } + + s.volsRWL.Lock() + defer s.volsRWL.Unlock() + + i, v := s.findVolNoLock("id", req.VolumeId) + if i < 0 { + return nil, status.Error(codes.NotFound, req.VolumeId) + } + + // TODO: NodeExpandVolume MUST be called after successful NodeStageVolume as we has STAGE_UNSTAGE_VOLUME node capacity. + resp := &csi.NodeExpandVolumeResponse{} + var requestCapacity int64 = 0 + if req.GetCapacityRange() != nil { + requestCapacity = req.CapacityRange.GetRequiredBytes() + resp.CapacityBytes = requestCapacity + } + + // fsCapacityKey is the key in the volume's attributes that is set to the file system's size. + fsCapacityKey := path.Join(s.nodeID, req.GetVolumePath(), "size") + // Update volume's fs capacity to requested size. + if requestCapacity > 0 { + v.VolumeContext[fsCapacityKey] = strconv.FormatInt(requestCapacity, 10) + s.vols[i] = v + } + if hookVal, hookMsg := s.execHook("NodeExpandVolumeEnd"); hookVal != codes.OK { + return nil, status.Errorf(hookVal, hookMsg) + } + + return resp, nil +} + +func (s *service) NodeGetCapabilities( + ctx context.Context, + req *csi.NodeGetCapabilitiesRequest) ( + *csi.NodeGetCapabilitiesResponse, error) { + + if hookVal, hookMsg := s.execHook("NodeGetCapabilities"); hookVal != codes.OK { + return nil, status.Errorf(hookVal, hookMsg) + } + capabilities := []*csi.NodeServiceCapability{ + { + Type: &csi.NodeServiceCapability_Rpc{ + Rpc: &csi.NodeServiceCapability_RPC{ + Type: csi.NodeServiceCapability_RPC_UNKNOWN, + }, + }, + }, + { + Type: &csi.NodeServiceCapability_Rpc{ + Rpc: &csi.NodeServiceCapability_RPC{ + Type: csi.NodeServiceCapability_RPC_STAGE_UNSTAGE_VOLUME, + }, + }, + }, + { + Type: &csi.NodeServiceCapability_Rpc{ + Rpc: &csi.NodeServiceCapability_RPC{ + Type: csi.NodeServiceCapability_RPC_GET_VOLUME_STATS, + }, + }, + }, + { + Type: &csi.NodeServiceCapability_Rpc{ + Rpc: &csi.NodeServiceCapability_RPC{ + Type: csi.NodeServiceCapability_RPC_VOLUME_CONDITION, + }, + }, + }, + } + if s.config.NodeExpansionRequired { + capabilities = append(capabilities, &csi.NodeServiceCapability{ + Type: &csi.NodeServiceCapability_Rpc{ + Rpc: &csi.NodeServiceCapability_RPC{ + Type: csi.NodeServiceCapability_RPC_EXPAND_VOLUME, + }, + }, + }) + } + + return &csi.NodeGetCapabilitiesResponse{ + Capabilities: capabilities, + }, nil +} + +func (s *service) NodeGetInfo(ctx context.Context, + req *csi.NodeGetInfoRequest) (*csi.NodeGetInfoResponse, error) { + if hookVal, hookMsg := s.execHook("NodeGetInfo"); hookVal != codes.OK { + return nil, status.Errorf(hookVal, hookMsg) + } + csiNodeResponse := &csi.NodeGetInfoResponse{ + NodeId: s.nodeID, + } + if s.config.AttachLimit > 0 { + csiNodeResponse.MaxVolumesPerNode = s.config.AttachLimit + } + if s.config.EnableTopology { + csiNodeResponse.AccessibleTopology = &csi.Topology{ + Segments: map[string]string{ + TopologyKey: TopologyValue, + }, + } + } + return csiNodeResponse, nil +} + +func (s *service) NodeGetVolumeStats(ctx context.Context, + req *csi.NodeGetVolumeStatsRequest) (*csi.NodeGetVolumeStatsResponse, error) { + + if hookVal, hookMsg := s.execHook("NodeGetVolumeStatsStart"); hookVal != codes.OK { + return nil, status.Errorf(hookVal, hookMsg) + } + + resp := &csi.NodeGetVolumeStatsResponse{ + VolumeCondition: &csi.VolumeCondition{}, + } + + if len(req.GetVolumeId()) == 0 { + return nil, status.Error(codes.InvalidArgument, "Volume ID cannot be empty") + } + + if len(req.GetVolumePath()) == 0 { + return nil, status.Error(codes.InvalidArgument, "Volume Path cannot be empty") + } + + i, v := s.findVolNoLock("id", req.VolumeId) + if i < 0 { + resp.VolumeCondition.Abnormal = true + resp.VolumeCondition.Message = "Volume not found" + return resp, status.Error(codes.NotFound, req.VolumeId) + } + + nodeMntPathKey := path.Join(s.nodeID, req.VolumePath) + + _, exists := v.VolumeContext[nodeMntPathKey] + if !exists { + msg := fmt.Sprintf("volume %q doest not exist on the specified path %q", req.VolumeId, req.VolumePath) + resp.VolumeCondition.Abnormal = true + resp.VolumeCondition.Message = msg + return resp, status.Errorf(codes.NotFound, msg) + } + + if hookVal, hookMsg := s.execHook("NodeGetVolumeStatsEnd"); hookVal != codes.OK { + return nil, status.Errorf(hookVal, hookMsg) + } + + resp.Usage = []*csi.VolumeUsage{ + { + Total: v.GetCapacityBytes(), + Unit: csi.VolumeUsage_BYTES, + }, + } + + return resp, nil +} + +// checkTargetExists checks if a given path exists. +func checkTargetExists(targetPath string) (bool, error) { + _, err := os.Stat(targetPath) + switch { + case err == nil: + return true, nil + case os.IsNotExist(err): + return false, nil + default: + return false, err + } +} diff --git a/test/e2e/storage/drivers/csi-test/mock/service/service.go b/test/e2e/storage/drivers/csi-test/mock/service/service.go new file mode 100644 index 00000000000..ff54ae9e506 --- /dev/null +++ b/test/e2e/storage/drivers/csi-test/mock/service/service.go @@ -0,0 +1,293 @@ +package service + +import ( + "fmt" + "reflect" + "strings" + "sync" + "sync/atomic" + + "k8s.io/klog" + + "github.com/container-storage-interface/spec/lib/go/csi" + "github.com/kubernetes-csi/csi-test/v4/mock/cache" + "golang.org/x/net/context" + "google.golang.org/grpc/codes" + + "github.com/golang/protobuf/ptypes" + + "github.com/robertkrimen/otto" +) + +const ( + // Name is the name of the CSI plug-in. + Name = "io.kubernetes.storage.mock" + + // VendorVersion is the version returned by GetPluginInfo. + VendorVersion = "0.3.0" + + // TopologyKey simulates a per-node topology. + TopologyKey = Name + "/node" + + // TopologyValue is the one, fixed node on which the driver runs. + TopologyValue = "some-mock-node" +) + +// Manifest is the SP's manifest. +var Manifest = map[string]string{ + "url": "https://github.com/kubernetes-csi/csi-test/mock", +} + +// JavaScript hooks to be run to perform various tests +type Hooks struct { + Globals string `yaml:"globals"` // will be executed once before all other scripts + CreateVolumeStart string `yaml:"createVolumeStart"` + CreateVolumeEnd string `yaml:"createVolumeEnd"` + DeleteVolumeStart string `yaml:"deleteVolumeStart"` + DeleteVolumeEnd string `yaml:"deleteVolumeEnd"` + ControllerPublishVolumeStart string `yaml:"controllerPublishVolumeStart"` + ControllerPublishVolumeEnd string `yaml:"controllerPublishVolumeEnd"` + ControllerUnpublishVolumeStart string `yaml:"controllerUnpublishVolumeStart"` + ControllerUnpublishVolumeEnd string `yaml:"controllerUnpublishVolumeEnd"` + ValidateVolumeCapabilities string `yaml:"validateVolumeCapabilities"` + ListVolumesStart string `yaml:"listVolumesStart"` + ListVolumesEnd string `yaml:"listVolumesEnd"` + GetCapacity string `yaml:"getCapacity"` + ControllerGetCapabilitiesStart string `yaml:"controllerGetCapabilitiesStart"` + ControllerGetCapabilitiesEnd string `yaml:"controllerGetCapabilitiesEnd"` + CreateSnapshotStart string `yaml:"createSnapshotStart"` + CreateSnapshotEnd string `yaml:"createSnapshotEnd"` + DeleteSnapshotStart string `yaml:"deleteSnapshotStart"` + DeleteSnapshotEnd string `yaml:"deleteSnapshotEnd"` + ListSnapshots string `yaml:"listSnapshots"` + ControllerExpandVolumeStart string `yaml:"controllerExpandVolumeStart"` + ControllerExpandVolumeEnd string `yaml:"controllerExpandVolumeEnd"` + NodeStageVolumeStart string `yaml:"nodeStageVolumeStart"` + NodeStageVolumeEnd string `yaml:"nodeStageVolumeEnd"` + NodeUnstageVolumeStart string `yaml:"nodeUnstageVolumeStart"` + NodeUnstageVolumeEnd string `yaml:"nodeUnstageVolumeEnd"` + NodePublishVolumeStart string `yaml:"nodePublishVolumeStart"` + NodePublishVolumeEnd string `yaml:"nodePublishVolumeEnd"` + NodeUnpublishVolumeStart string `yaml:"nodeUnpublishVolumeStart"` + NodeUnpublishVolumeEnd string `yaml:"nodeUnpublishVolumeEnd"` + NodeExpandVolumeStart string `yaml:"nodeExpandVolumeStart"` + NodeExpandVolumeEnd string `yaml:"nodeExpandVolumeEnd"` + NodeGetCapabilities string `yaml:"nodeGetCapabilities"` + NodeGetInfo string `yaml:"nodeGetInfo"` + NodeGetVolumeStatsStart string `yaml:"nodeGetVolumeStatsStart"` + NodeGetVolumeStatsEnd string `yaml:"nodeGetVolumeStatsEnd"` +} + +type Config struct { + DisableAttach bool + DriverName string + AttachLimit int64 + NodeExpansionRequired bool + DisableControllerExpansion bool + DisableOnlineExpansion bool + PermissiveTargetPath bool + EnableTopology bool + ExecHooks *Hooks +} + +// Service is the CSI Mock service provider. +type Service interface { + csi.ControllerServer + csi.IdentityServer + csi.NodeServer +} + +type service struct { + sync.Mutex + nodeID string + vols []csi.Volume + volsRWL sync.RWMutex + volsNID uint64 + snapshots cache.SnapshotCache + snapshotsNID uint64 + config Config + hooksVm *otto.Otto +} + +type Volume struct { + VolumeCSI csi.Volume + NodeID string + ISStaged bool + ISPublished bool + ISEphemeral bool + ISControllerPublished bool + StageTargetPath string + TargetPath string +} + +var MockVolumes map[string]Volume + +// New returns a new Service. +func New(config Config) Service { + s := &service{ + nodeID: config.DriverName, + config: config, + } + if config.ExecHooks != nil { + s.hooksVm = otto.New() + s.hooksVm.Run(grpcJSCodes) // set global variables with gRPC error codes + _, err := s.hooksVm.Run(s.config.ExecHooks.Globals) + if err != nil { + klog.Exitf("Error encountered in the global exec hook: %v. Exiting\n", err) + } + } + s.snapshots = cache.NewSnapshotCache() + s.vols = []csi.Volume{ + s.newVolume("Mock Volume 1", gib100), + s.newVolume("Mock Volume 2", gib100), + s.newVolume("Mock Volume 3", gib100), + } + MockVolumes = map[string]Volume{} + + s.snapshots.Add(s.newSnapshot("Mock Snapshot 1", "1", map[string]string{"Description": "snapshot 1"})) + s.snapshots.Add(s.newSnapshot("Mock Snapshot 2", "2", map[string]string{"Description": "snapshot 2"})) + s.snapshots.Add(s.newSnapshot("Mock Snapshot 3", "3", map[string]string{"Description": "snapshot 3"})) + + return s +} + +const ( + kib int64 = 1024 + mib int64 = kib * 1024 + gib int64 = mib * 1024 + gib100 int64 = gib * 100 + tib int64 = gib * 1024 + tib100 int64 = tib * 100 +) + +func (s *service) newVolume(name string, capcity int64) csi.Volume { + vol := csi.Volume{ + VolumeId: fmt.Sprintf("%d", atomic.AddUint64(&s.volsNID, 1)), + VolumeContext: map[string]string{"name": name}, + CapacityBytes: capcity, + } + s.setTopology(&vol) + return vol +} + +func (s *service) newVolumeFromSnapshot(name string, capacity int64, snapshotID int) csi.Volume { + vol := s.newVolume(name, capacity) + vol.ContentSource = &csi.VolumeContentSource{ + Type: &csi.VolumeContentSource_Snapshot{ + Snapshot: &csi.VolumeContentSource_SnapshotSource{ + SnapshotId: fmt.Sprintf("%d", snapshotID), + }, + }, + } + s.setTopology(&vol) + return vol +} + +func (s *service) newVolumeFromVolume(name string, capacity int64, volumeID int) csi.Volume { + vol := s.newVolume(name, capacity) + vol.ContentSource = &csi.VolumeContentSource{ + Type: &csi.VolumeContentSource_Volume{ + Volume: &csi.VolumeContentSource_VolumeSource{ + VolumeId: fmt.Sprintf("%d", volumeID), + }, + }, + } + s.setTopology(&vol) + return vol +} + +func (s *service) setTopology(vol *csi.Volume) { + if s.config.EnableTopology { + vol.AccessibleTopology = []*csi.Topology{ + &csi.Topology{ + Segments: map[string]string{ + TopologyKey: TopologyValue, + }, + }, + } + } +} + +func (s *service) findVol(k, v string) (volIdx int, volInfo csi.Volume) { + s.volsRWL.RLock() + defer s.volsRWL.RUnlock() + return s.findVolNoLock(k, v) +} + +func (s *service) findVolNoLock(k, v string) (volIdx int, volInfo csi.Volume) { + volIdx = -1 + + for i, vi := range s.vols { + switch k { + case "id": + if strings.EqualFold(v, vi.GetVolumeId()) { + return i, vi + } + case "name": + if n, ok := vi.VolumeContext["name"]; ok && strings.EqualFold(v, n) { + return i, vi + } + } + } + + return +} + +func (s *service) findVolByName( + ctx context.Context, name string) (int, csi.Volume) { + + return s.findVol("name", name) +} + +func (s *service) findVolByID( + ctx context.Context, id string) (int, csi.Volume) { + + return s.findVol("id", id) +} + +func (s *service) newSnapshot(name, sourceVolumeId string, parameters map[string]string) cache.Snapshot { + + ptime := ptypes.TimestampNow() + return cache.Snapshot{ + Name: name, + Parameters: parameters, + SnapshotCSI: csi.Snapshot{ + SnapshotId: fmt.Sprintf("%d", atomic.AddUint64(&s.snapshotsNID, 1)), + CreationTime: ptime, + SourceVolumeId: sourceVolumeId, + ReadyToUse: true, + }, + } +} + +// getAttachCount returns the number of attached volumes on the node. +func (s *service) getAttachCount(devPathKey string) int64 { + var count int64 + for _, v := range s.vols { + if device := v.VolumeContext[devPathKey]; device != "" { + count++ + } + } + return count +} + +func (s *service) execHook(hookName string) (codes.Code, string) { + if s.hooksVm != nil { + script := reflect.ValueOf(*s.config.ExecHooks).FieldByName(hookName).String() + if len(script) > 0 { + result, err := s.hooksVm.Run(script) + if err != nil { + klog.Exitf("Exec hook %s error: %v; exiting\n", hookName, err) + } + rv, err := result.ToInteger() + if err == nil { + // Function returned an integer, use it + return codes.Code(rv), fmt.Sprintf("Exec hook %s returned non-OK code", hookName) + } else { + // Function returned non-integer data type, discard it + return codes.OK, "" + } + } + } + return codes.OK, "" +}