diff --git a/pkg/kubelet/cm/dra/claiminfo.go b/pkg/kubelet/cm/dra/claiminfo.go index 9c70b487085..5602707a0a4 100644 --- a/pkg/kubelet/cm/dra/claiminfo.go +++ b/pkg/kubelet/cm/dra/claiminfo.go @@ -102,6 +102,10 @@ func (info *ClaimInfo) setCDIDevices(pluginName string, cdiDevices []string) err info.CDIDevices = make(map[string][]string) } + if info.annotations == nil { + info.annotations = make(map[string][]kubecontainer.Annotation) + } + info.CDIDevices[pluginName] = cdiDevices info.annotations[pluginName] = annotations diff --git a/pkg/kubelet/cm/dra/claiminfo_test.go b/pkg/kubelet/cm/dra/claiminfo_test.go new file mode 100644 index 00000000000..58652cc605c --- /dev/null +++ b/pkg/kubelet/cm/dra/claiminfo_test.go @@ -0,0 +1,894 @@ +/* +Copyright 2024 The Kubernetes Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +*/ + +package dra + +import ( + "errors" + "fmt" + "path" + "reflect" + "sort" + "testing" + + "github.com/stretchr/testify/assert" + resourcev1alpha2 "k8s.io/api/resource/v1alpha2" + metav1 "k8s.io/apimachinery/pkg/apis/meta/v1" + "k8s.io/apimachinery/pkg/types" + "k8s.io/apimachinery/pkg/util/sets" + "k8s.io/kubernetes/pkg/kubelet/cm/dra/state" + kubecontainer "k8s.io/kubernetes/pkg/kubelet/container" +) + +// ClaimInfo test cases + +func TestNewClaimInfoFromClaim(t *testing.T) { + namespace := "test-namespace" + className := "test-class" + driverName := "test-plugin" + claimUID := types.UID("claim-uid") + claimName := "test-claim" + + for _, test := range []struct { + description string + claim *resourcev1alpha2.ResourceClaim + expectedResult *ClaimInfo + }{ + { + description: "successfully created object", + claim: &resourcev1alpha2.ResourceClaim{ + ObjectMeta: metav1.ObjectMeta{ + UID: claimUID, + Name: claimName, + Namespace: namespace, + }, + Status: resourcev1alpha2.ResourceClaimStatus{ + DriverName: driverName, + Allocation: &resourcev1alpha2.AllocationResult{ + ResourceHandles: []resourcev1alpha2.ResourceHandle{}, + }, + }, + Spec: resourcev1alpha2.ResourceClaimSpec{ + ResourceClassName: className, + }, + }, + expectedResult: &ClaimInfo{ + ClaimInfoState: state.ClaimInfoState{ + DriverName: driverName, + ClassName: className, + ClaimUID: claimUID, + ClaimName: claimName, + Namespace: claimName, + PodUIDs: sets.New[string](), + ResourceHandles: []resourcev1alpha2.ResourceHandle{ + {}, + }, + CDIDevices: make(map[string][]string), + }, + }, + }, + { + description: "successfully created object with empty allocation", + claim: &resourcev1alpha2.ResourceClaim{ + ObjectMeta: metav1.ObjectMeta{ + UID: claimUID, + Name: claimName, + Namespace: namespace, + }, + Status: resourcev1alpha2.ResourceClaimStatus{ + DriverName: driverName, + Allocation: &resourcev1alpha2.AllocationResult{}, + }, + Spec: resourcev1alpha2.ResourceClaimSpec{ + ResourceClassName: className, + }, + }, + expectedResult: &ClaimInfo{ + ClaimInfoState: state.ClaimInfoState{ + DriverName: driverName, + ClassName: className, + ClaimUID: claimUID, + ClaimName: claimName, + Namespace: claimName, + PodUIDs: sets.New[string](), + ResourceHandles: []resourcev1alpha2.ResourceHandle{ + {}, + }, + CDIDevices: make(map[string][]string), + }, + }, + }, + } { + t.Run(test.description, func(t *testing.T) { + result := newClaimInfoFromClaim(test.claim) + if reflect.DeepEqual(result, test.expectedResult) { + t.Errorf("Expected %v, but got %v", test.expectedResult, result) + } + }) + } +} + +func TestNewClaimInfoFromState(t *testing.T) { + for _, test := range []struct { + description string + state *state.ClaimInfoState + expectedResult *ClaimInfo + }{ + { + description: "successfully created object", + state: &state.ClaimInfoState{ + DriverName: "test-driver", + ClassName: "test-class", + ClaimUID: "test-uid", + ClaimName: "test-claim", + Namespace: "test-namespace", + PodUIDs: sets.New[string]("test-pod-uid"), + ResourceHandles: []resourcev1alpha2.ResourceHandle{}, + CDIDevices: map[string][]string{}, + }, + }, + } { + t.Run(test.description, func(t *testing.T) { + result := newClaimInfoFromState(test.state) + if reflect.DeepEqual(result, test.expectedResult) { + t.Errorf("Expected %v, but got %v", test.expectedResult, result) + } + }) + } +} + +func TestClaimInfoSetCDIDevices(t *testing.T) { + claimUID := types.UID("claim-uid") + pluginName := "test-plugin" + device := "vendor.com/device=device1" + annotationName := fmt.Sprintf("cdi.k8s.io/%s_%s", pluginName, claimUID) + for _, test := range []struct { + description string + claimInfo *ClaimInfo + devices []string + expectedCDIDevices map[string][]string + expectedAnnotations map[string][]kubecontainer.Annotation + wantErr bool + }{ + { + description: "successfully add one device", + claimInfo: &ClaimInfo{ + ClaimInfoState: state.ClaimInfoState{ + DriverName: pluginName, + ClaimUID: claimUID, + }, + }, + devices: []string{device}, + expectedCDIDevices: map[string][]string{ + pluginName: {device}, + }, + expectedAnnotations: map[string][]kubecontainer.Annotation{ + pluginName: { + { + Name: annotationName, + Value: device, + }, + }, + }, + }, + { + description: "empty list of devices", + claimInfo: &ClaimInfo{ + ClaimInfoState: state.ClaimInfoState{ + DriverName: pluginName, + ClaimUID: claimUID, + }, + }, + devices: []string{}, + expectedCDIDevices: map[string][]string{pluginName: {}}, + expectedAnnotations: map[string][]kubecontainer.Annotation{pluginName: nil}, + }, + { + description: "incorrect device format", + claimInfo: &ClaimInfo{ + ClaimInfoState: state.ClaimInfoState{ + DriverName: pluginName, + ClaimUID: claimUID, + }, + }, + devices: []string{"incorrect"}, + wantErr: true, + }, + } { + t.Run(test.description, func(t *testing.T) { + err := test.claimInfo.setCDIDevices(pluginName, test.devices) + if test.wantErr { + assert.Error(t, err) + return + } + assert.NoError(t, err) + assert.Equal(t, test.expectedCDIDevices, test.claimInfo.CDIDevices) + assert.Equal(t, test.expectedAnnotations, test.claimInfo.annotations) + }) + } +} + +func TestClaimInfoAnnotationsAsList(t *testing.T) { + for _, test := range []struct { + description string + claimInfo *ClaimInfo + expectedResult []kubecontainer.Annotation + }{ + { + description: "empty annotations", + claimInfo: &ClaimInfo{ + annotations: map[string][]kubecontainer.Annotation{}, + }, + }, + { + description: "nil annotations", + claimInfo: &ClaimInfo{}, + }, + { + description: "valid annotations", + claimInfo: &ClaimInfo{ + annotations: map[string][]kubecontainer.Annotation{ + "test-plugin1": { + { + Name: "cdi.k8s.io/test-plugin1_claim-uid1", + Value: "vendor.com/device=device1", + }, + { + Name: "cdi.k8s.io/test-plugin1_claim-uid2", + Value: "vendor.com/device=device2", + }, + }, + "test-plugin2": { + { + Name: "cdi.k8s.io/test-plugin2_claim-uid1", + Value: "vendor.com/device=device1", + }, + { + Name: "cdi.k8s.io/test-plugin2_claim-uid2", + Value: "vendor.com/device=device2", + }, + }, + }, + }, + expectedResult: []kubecontainer.Annotation{ + { + Name: "cdi.k8s.io/test-plugin1_claim-uid1", + Value: "vendor.com/device=device1", + }, + { + Name: "cdi.k8s.io/test-plugin1_claim-uid2", + Value: "vendor.com/device=device2", + }, + { + Name: "cdi.k8s.io/test-plugin2_claim-uid1", + Value: "vendor.com/device=device1", + }, + { + Name: "cdi.k8s.io/test-plugin2_claim-uid2", + Value: "vendor.com/device=device2", + }, + }, + }, + } { + t.Run(test.description, func(t *testing.T) { + result := test.claimInfo.annotationsAsList() + sort.Slice(result, func(i, j int) bool { + return result[i].Name < result[j].Name + }) + assert.Equal(t, test.expectedResult, result) + }) + } +} + +func TestClaimInfoCDIdevicesAsList(t *testing.T) { + for _, test := range []struct { + description string + claimInfo *ClaimInfo + expectedResult []kubecontainer.CDIDevice + }{ + { + description: "empty CDI devices", + claimInfo: &ClaimInfo{ + ClaimInfoState: state.ClaimInfoState{ + CDIDevices: map[string][]string{}, + }, + }, + }, + { + description: "nil CDI devices", + claimInfo: &ClaimInfo{}, + }, + { + description: "valid CDI devices", + claimInfo: &ClaimInfo{ + ClaimInfoState: state.ClaimInfoState{ + CDIDevices: map[string][]string{ + "test-plugin1": { + "vendor.com/device=device1", + "vendor.com/device=device2", + }, + "test-plugin2": { + "vendor.com/device=device1", + "vendor.com/device=device2", + }, + }, + }, + }, + expectedResult: []kubecontainer.CDIDevice{ + { + Name: "vendor.com/device=device1", + }, + { + Name: "vendor.com/device=device1", + }, + { + Name: "vendor.com/device=device2", + }, + { + Name: "vendor.com/device=device2", + }, + }, + }, + } { + t.Run(test.description, func(t *testing.T) { + result := test.claimInfo.cdiDevicesAsList() + sort.Slice(result, func(i, j int) bool { + return result[i].Name < result[j].Name + }) + assert.Equal(t, test.expectedResult, result) + }) + } +} +func TestClaimInfoAddPodReference(t *testing.T) { + podUID := types.UID("pod-uid") + for _, test := range []struct { + description string + claimInfo *ClaimInfo + expectedLen int + }{ + { + description: "successfully add pod reference", + claimInfo: &ClaimInfo{ + ClaimInfoState: state.ClaimInfoState{ + PodUIDs: sets.New[string](), + }, + }, + expectedLen: 1, + }, + { + description: "duplicate pod reference", + claimInfo: &ClaimInfo{ + ClaimInfoState: state.ClaimInfoState{ + PodUIDs: sets.New[string](string(podUID)), + }, + }, + expectedLen: 1, + }, + { + description: "duplicate pod reference", + claimInfo: &ClaimInfo{ + ClaimInfoState: state.ClaimInfoState{ + PodUIDs: sets.New[string]("pod-uid1"), + }, + }, + expectedLen: 2, + }, + } { + t.Run(test.description, func(t *testing.T) { + test.claimInfo.addPodReference(podUID) + assert.True(t, test.claimInfo.hasPodReference(podUID)) + assert.Len(t, test.claimInfo.PodUIDs, test.expectedLen) + }) + } +} + +func TestClaimInfoHasPodReference(t *testing.T) { + podUID := types.UID("pod-uid") + for _, test := range []struct { + description string + claimInfo *ClaimInfo + expectedResult bool + }{ + { + description: "claim doesn't reference pod", + claimInfo: &ClaimInfo{ + ClaimInfoState: state.ClaimInfoState{ + PodUIDs: sets.New[string](), + }, + }, + }, + { + description: "claim references pod", + claimInfo: &ClaimInfo{ + ClaimInfoState: state.ClaimInfoState{ + PodUIDs: sets.New[string](string(podUID)), + }, + }, + expectedResult: true, + }, + { + description: "empty claim info", + claimInfo: &ClaimInfo{}, + }, + } { + t.Run(test.description, func(t *testing.T) { + assert.Equal(t, test.claimInfo.hasPodReference(podUID), test.expectedResult) + }) + } +} + +func TestClaimInfoDeletePodReference(t *testing.T) { + podUID := types.UID("pod-uid") + for _, test := range []struct { + description string + claimInfo *ClaimInfo + }{ + { + description: "claim doesn't reference pod", + claimInfo: &ClaimInfo{ + ClaimInfoState: state.ClaimInfoState{ + PodUIDs: sets.New[string](), + }, + }, + }, + { + description: "claim references pod", + claimInfo: &ClaimInfo{ + ClaimInfoState: state.ClaimInfoState{ + PodUIDs: sets.New[string](string(podUID)), + }, + }, + }, + { + description: "empty claim info", + claimInfo: &ClaimInfo{}, + }, + } { + t.Run(test.description, func(t *testing.T) { + test.claimInfo.deletePodReference(podUID) + assert.False(t, test.claimInfo.hasPodReference(podUID)) + }) + } +} + +func TestClaimInfoSetPrepared(t *testing.T) { + for _, test := range []struct { + description string + claimInfo *ClaimInfo + }{ + { + description: "claim info is not prepared", + claimInfo: &ClaimInfo{ + prepared: false, + }, + }, + { + description: "claim info is prepared", + claimInfo: &ClaimInfo{ + prepared: true, + }, + }, + { + description: "empty claim info", + claimInfo: &ClaimInfo{}, + }, + } { + t.Run(test.description, func(t *testing.T) { + test.claimInfo.setPrepared() + assert.Equal(t, test.claimInfo.isPrepared(), true) + }) + } +} + +func TestClaimInfoIsPrepared(t *testing.T) { + for _, test := range []struct { + description string + claimInfo *ClaimInfo + expectedResult bool + }{ + { + description: "claim info is not prepared", + claimInfo: &ClaimInfo{ + prepared: false, + }, + expectedResult: false, + }, + { + description: "claim info is prepared", + claimInfo: &ClaimInfo{ + prepared: true, + }, + expectedResult: true, + }, + { + description: "empty claim info", + claimInfo: &ClaimInfo{}, + expectedResult: false, + }, + } { + t.Run(test.description, func(t *testing.T) { + assert.Equal(t, test.claimInfo.isPrepared(), test.expectedResult) + }) + } +} + +// claimInfoCache test cases +func TestNewClaimInfoCache(t *testing.T) { + for _, test := range []struct { + description string + stateDir string + checkpointName string + wantErr bool + }{ + { + description: "successfully created cache", + stateDir: t.TempDir(), + checkpointName: "test-checkpoint", + }, + { + description: "empty parameters", + wantErr: true, + }, + { + description: "empty checkpoint name", + stateDir: t.TempDir(), + wantErr: true, + }, + { + description: "incorrect checkpoint name", + stateDir: path.Join(t.TempDir(), "incorrect checkpoint"), + wantErr: true, + }, + } { + t.Run(test.description, func(t *testing.T) { + result, err := newClaimInfoCache(test.stateDir, test.checkpointName) + if test.wantErr { + assert.Error(t, err) + return + } + assert.NoError(t, err) + assert.NotNil(t, result) + }) + } +} + +func TestClaimInfoCacheWithLock(t *testing.T) { + for _, test := range []struct { + description string + funcGen func(cache *claimInfoCache) func() error + wantErr bool + }{ + { + description: "cache is locked inside a function", + funcGen: func(cache *claimInfoCache) func() error { + return func() error { + if cache.RWMutex.TryLock() { + return errors.New("Lock succeeded") + } + return nil + } + }, + }, + { + description: "cache is Rlocked inside a function", + funcGen: func(cache *claimInfoCache) func() error { + return func() error { + if cache.RWMutex.TryRLock() { + return errors.New("RLock succeeded") + } + return nil + } + }, + }, + { + description: "successfully called function", + funcGen: func(cache *claimInfoCache) func() error { + return func() error { + return nil + } + }, + }, + { + description: "erroring function", + funcGen: func(cache *claimInfoCache) func() error { + return func() error { + return errors.New("test error") + } + }, + wantErr: true, + }, + } { + t.Run(test.description, func(t *testing.T) { + cache, err := newClaimInfoCache(t.TempDir(), "test-checkpoint") + assert.NoError(t, err) + assert.NotNil(t, cache) + err = cache.withLock(test.funcGen(cache)) + if test.wantErr { + assert.Error(t, err) + return + } + assert.NoError(t, err) + }) + } +} + +func TestClaimInfoCacheWithRLock(t *testing.T) { + for _, test := range []struct { + description string + funcGen func(cache *claimInfoCache) func() error + wantErr bool + }{ + { + description: "RLock-ed cache allows another RLock", + funcGen: func(cache *claimInfoCache) func() error { + return func() error { + if !cache.RWMutex.TryRLock() { + return errors.New("RLock failed") + } + return nil + } + }, + }, + { + description: "cache is locked inside a function", + funcGen: func(cache *claimInfoCache) func() error { + return func() error { + if cache.RWMutex.TryLock() { + return errors.New("Lock succeeded") + } + return nil + } + }, + }, + { + description: "successfully called function", + funcGen: func(cache *claimInfoCache) func() error { + return func() error { + return nil + } + }, + }, + { + description: "erroring function", + funcGen: func(cache *claimInfoCache) func() error { + return func() error { + return errors.New("test error") + } + }, + wantErr: true, + }, + } { + t.Run(test.description, func(t *testing.T) { + cache, err := newClaimInfoCache(t.TempDir(), "test-checkpoint") + assert.NoError(t, err) + assert.NotNil(t, cache) + err = cache.withRLock(test.funcGen(cache)) + if test.wantErr { + assert.Error(t, err) + return + } + assert.NoError(t, err) + }) + } +} + +func TestClaimInfoCacheAdd(t *testing.T) { + for _, test := range []struct { + description string + claimInfo *ClaimInfo + }{ + { + description: "claimInfo successfully added", + claimInfo: &ClaimInfo{ + ClaimInfoState: state.ClaimInfoState{ + ClaimName: "test-claim", + Namespace: "test-namespace", + }, + }, + }, + } { + t.Run(test.description, func(t *testing.T) { + cache, err := newClaimInfoCache(t.TempDir(), "test-checkpoint") + assert.NoError(t, err) + assert.NotNil(t, cache) + cache.add(test.claimInfo) + assert.True(t, cache.contains(test.claimInfo.ClaimName, test.claimInfo.Namespace)) + }) + } +} + +func TestClaimInfoCacheContains(t *testing.T) { + claimName := "test-claim" + namespace := "test-namespace" + for _, test := range []struct { + description string + claimInfo *ClaimInfo + claimInfoCache *claimInfoCache + expectedResult bool + }{ + { + description: "cache hit", + claimInfoCache: &claimInfoCache{ + claimInfo: map[string]*ClaimInfo{ + namespace + "/" + claimName: { + ClaimInfoState: state.ClaimInfoState{ + ClaimName: claimName, + Namespace: namespace, + }, + }, + }, + }, + claimInfo: &ClaimInfo{ + ClaimInfoState: state.ClaimInfoState{ + ClaimName: claimName, + Namespace: namespace, + }, + }, + expectedResult: true, + }, + { + description: "cache miss", + claimInfoCache: &claimInfoCache{}, + claimInfo: &ClaimInfo{ + ClaimInfoState: state.ClaimInfoState{ + ClaimName: claimName, + Namespace: namespace, + }, + }, + }, + { + description: "cache miss: empty cache and empty claim info", + claimInfoCache: &claimInfoCache{}, + claimInfo: &ClaimInfo{ + ClaimInfoState: state.ClaimInfoState{}, + }, + }, + } { + t.Run(test.description, func(t *testing.T) { + assert.Equal(t, test.expectedResult, test.claimInfoCache.contains(test.claimInfo.ClaimName, test.claimInfo.Namespace)) + }) + } +} + +func TestClaimInfoCacheGet(t *testing.T) { + claimName := "test-claim" + namespace := "test-namespace" + for _, test := range []struct { + description string + claimInfoCache *claimInfoCache + expectedNil bool + expectedExists bool + }{ + { + description: "cache hit", + claimInfoCache: &claimInfoCache{ + claimInfo: map[string]*ClaimInfo{ + namespace + "/" + claimName: { + ClaimInfoState: state.ClaimInfoState{ + ClaimName: claimName, + Namespace: namespace, + }, + }, + }, + }, + expectedExists: true, + }, + { + description: "cache miss", + claimInfoCache: &claimInfoCache{}, + expectedNil: true, + }, + } { + t.Run(test.description, func(t *testing.T) { + result, exists := test.claimInfoCache.get(claimName, namespace) + assert.Equal(t, test.expectedExists, exists) + assert.Equal(t, test.expectedNil, result == nil) + }) + } +} + +func TestClaimInfoCacheDelete(t *testing.T) { + claimName := "test-claim" + namespace := "test-namespace" + for _, test := range []struct { + description string + claimInfoCache *claimInfoCache + }{ + { + description: "item in cache", + claimInfoCache: &claimInfoCache{ + claimInfo: map[string]*ClaimInfo{ + claimName + namespace: { + ClaimInfoState: state.ClaimInfoState{ + ClaimName: claimName, + Namespace: namespace, + }, + }, + }, + }, + }, + { + description: "item not in cache", + claimInfoCache: &claimInfoCache{}, + }, + } { + t.Run(test.description, func(t *testing.T) { + test.claimInfoCache.delete(claimName, namespace) + assert.False(t, test.claimInfoCache.contains(claimName, namespace)) + }) + } +} + +func TestClaimInfoCacheHasPodReference(t *testing.T) { + claimName := "test-claim" + namespace := "test-namespace" + uid := types.UID("test-uid") + for _, test := range []struct { + description string + claimInfoCache *claimInfoCache + expectedResult bool + }{ + { + description: "uid is referenced", + claimInfoCache: &claimInfoCache{ + claimInfo: map[string]*ClaimInfo{ + claimName + namespace: { + ClaimInfoState: state.ClaimInfoState{ + ClaimName: claimName, + Namespace: namespace, + PodUIDs: sets.New[string](string(uid)), + }, + }, + }, + }, + expectedResult: true, + }, + { + description: "uid is not referenced", + claimInfoCache: &claimInfoCache{}, + }, + } { + t.Run(test.description, func(t *testing.T) { + assert.Equal(t, test.expectedResult, test.claimInfoCache.hasPodReference(uid)) + }) + } +} + +func TestSyncToCheckpoint(t *testing.T) { + for _, test := range []struct { + description string + stateDir string + checkpointName string + wantErr bool + }{ + { + description: "successfully checkpointed cache", + stateDir: t.TempDir(), + checkpointName: "test-checkpoint", + }, + } { + t.Run(test.description, func(t *testing.T) { + cache, err := newClaimInfoCache(test.stateDir, test.checkpointName) + assert.NoError(t, err) + err = cache.syncToCheckpoint() + if test.wantErr { + assert.Error(t, err) + return + } + assert.NoError(t, err) + }) + } +}