diff --git a/pkg/kubelet/cm/container_manager_linux.go b/pkg/kubelet/cm/container_manager_linux.go index 4ed1aa7b591..fb666a8333f 100644 --- a/pkg/kubelet/cm/container_manager_linux.go +++ b/pkg/kubelet/cm/container_manager_linux.go @@ -305,7 +305,7 @@ func NewContainerManager(mountUtil mount.Interface, cadvisorInterface cadvisor.I } cm.topologyManager.AddHintProvider(cm.deviceManager) - // initialize DRA manager + // Initialize DRA manager if utilfeature.DefaultFeatureGate.Enabled(kubefeatures.DynamicResourceAllocation) { klog.InfoS("Creating Dynamic Resource Allocation (DRA) manager") cm.draManager, err = dra.NewManagerImpl(kubeClient, nodeConfig.KubeletRootDir, nodeConfig.NodeName) @@ -564,6 +564,14 @@ func (cm *containerManagerImpl) Start(node *v1.Node, containerMap, containerRunningSet := buildContainerMapAndRunningSetFromRuntime(ctx, runtimeService) + // Initialize DRA manager + if utilfeature.DefaultFeatureGate.Enabled(kubefeatures.DynamicResourceAllocation) { + err := cm.draManager.Start(dra.ActivePodsFunc(activePods), sourcesReady) + if err != nil { + return fmt.Errorf("start dra manager error: %w", err) + } + } + // Initialize CPU manager err := cm.cpuManager.Start(cpumanager.ActivePodsFunc(activePods), sourcesReady, podStatusProvider, runtimeService, containerMap) if err != nil { @@ -959,7 +967,6 @@ func (cm *containerManagerImpl) GetDynamicResources(pod *v1.Pod, container *v1.C } for _, containerClaimInfo := range containerClaimInfos { var claimResources []*podresourcesapi.ClaimResource - containerClaimInfo.RLock() // TODO: Currently we maintain a list of ClaimResources, each of which contains // a set of CDIDevices from a different kubelet plugin. In the future we may want to // include the name of the kubelet plugin and/or other types of resources that are @@ -971,7 +978,6 @@ func (cm *containerManagerImpl) GetDynamicResources(pod *v1.Pod, container *v1.C } claimResources = append(claimResources, &podresourcesapi.ClaimResource{CDIDevices: cdiDevices}) } - containerClaimInfo.RUnlock() containerDynamicResource := podresourcesapi.DynamicResource{ ClassName: containerClaimInfo.ClassName, ClaimName: containerClaimInfo.ClaimName, diff --git a/pkg/kubelet/cm/dra/claiminfo.go b/pkg/kubelet/cm/dra/claiminfo.go index d369b8d3e33..5602707a0a4 100644 --- a/pkg/kubelet/cm/dra/claiminfo.go +++ b/pkg/kubelet/cm/dra/claiminfo.go @@ -30,8 +30,8 @@ import ( // ClaimInfo holds information required // to prepare and unprepare a resource claim. +// +k8s:deepcopy-gen=true type ClaimInfo struct { - sync.RWMutex state.ClaimInfoState // annotations is a mapping of container annotations per DRA plugin associated with // a prepared resource @@ -39,24 +39,57 @@ type ClaimInfo struct { prepared bool } -func (info *ClaimInfo) addPodReference(podUID types.UID) { - info.Lock() - defer info.Unlock() - - info.PodUIDs.Insert(string(podUID)) +// claimInfoCache is a cache of processed resource claims keyed by namespace/claimname. +type claimInfoCache struct { + sync.RWMutex + state state.CheckpointState + claimInfo map[string]*ClaimInfo } -func (info *ClaimInfo) deletePodReference(podUID types.UID) { - info.Lock() - defer info.Unlock() - - info.PodUIDs.Delete(string(podUID)) +// newClaimInfoFromClaim creates a new claim info from a resource claim. +func newClaimInfoFromClaim(claim *resourcev1alpha2.ResourceClaim) *ClaimInfo { + // Grab the allocation.resourceHandles. If there are no + // allocation.resourceHandles, create a single resourceHandle with no + // content. 
This will trigger processing of this claim by a single + // kubelet plugin whose name matches resourceClaim.Status.DriverName. + resourceHandles := claim.Status.Allocation.ResourceHandles + if len(resourceHandles) == 0 { + resourceHandles = make([]resourcev1alpha2.ResourceHandle, 1) + } + claimInfoState := state.ClaimInfoState{ + DriverName: claim.Status.DriverName, + ClassName: claim.Spec.ResourceClassName, + ClaimUID: claim.UID, + ClaimName: claim.Name, + Namespace: claim.Namespace, + PodUIDs: sets.New[string](), + ResourceHandles: resourceHandles, + CDIDevices: make(map[string][]string), + } + info := &ClaimInfo{ + ClaimInfoState: claimInfoState, + annotations: make(map[string][]kubecontainer.Annotation), + prepared: false, + } + return info } -func (info *ClaimInfo) addCDIDevices(pluginName string, cdiDevices []string) error { - info.Lock() - defer info.Unlock() +// newClaimInfoFromClaim creates a new claim info from a checkpointed claim info state object. +func newClaimInfoFromState(state *state.ClaimInfoState) *ClaimInfo { + info := &ClaimInfo{ + ClaimInfoState: *state.DeepCopy(), + annotations: make(map[string][]kubecontainer.Annotation), + prepared: false, + } + for pluginName, devices := range info.CDIDevices { + annotations, _ := cdi.GenerateAnnotations(info.ClaimUID, info.DriverName, devices) + info.annotations[pluginName] = append(info.annotations[pluginName], annotations...) + } + return info +} +// setCDIDevices adds a set of CDI devices to the claim info. +func (info *ClaimInfo) setCDIDevices(pluginName string, cdiDevices []string) error { // NOTE: Passing CDI device names as annotations is a temporary solution // It will be removed after all runtimes are updated // to get CDI device names from the ContainerConfig.CDIDevices field @@ -69,6 +102,10 @@ func (info *ClaimInfo) addCDIDevices(pluginName string, cdiDevices []string) err info.CDIDevices = make(map[string][]string) } + if info.annotations == nil { + info.annotations = make(map[string][]kubecontainer.Annotation) + } + info.CDIDevices[pluginName] = cdiDevices info.annotations[pluginName] = annotations @@ -77,9 +114,6 @@ func (info *ClaimInfo) addCDIDevices(pluginName string, cdiDevices []string) err // annotationsAsList returns container annotations as a single list. func (info *ClaimInfo) annotationsAsList() []kubecontainer.Annotation { - info.RLock() - defer info.RUnlock() - var lst []kubecontainer.Annotation for _, v := range info.annotations { lst = append(lst, v...) @@ -87,53 +121,43 @@ func (info *ClaimInfo) annotationsAsList() []kubecontainer.Annotation { return lst } -// claimInfoCache is a cache of processed resource claims keyed by namespace + claim name. -type claimInfoCache struct { - sync.RWMutex - state state.CheckpointState - claimInfo map[string]*ClaimInfo +// cdiDevicesAsList returns a list of CDIDevices from the provided claim info. 
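+// It flattens the per-plugin CDIDevices map into a single slice, so the relative
+// order of devices coming from different plugins is not guaranteed.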
+func (info *ClaimInfo) cdiDevicesAsList() []kubecontainer.CDIDevice { + var cdiDevices []kubecontainer.CDIDevice + for _, devices := range info.CDIDevices { + for _, device := range devices { + cdiDevices = append(cdiDevices, kubecontainer.CDIDevice{Name: device}) + } + } + return cdiDevices } -func newClaimInfo(driverName, className string, claimUID types.UID, claimName, namespace string, podUIDs sets.Set[string], resourceHandles []resourcev1alpha2.ResourceHandle) *ClaimInfo { - claimInfoState := state.ClaimInfoState{ - DriverName: driverName, - ClassName: className, - ClaimUID: claimUID, - ClaimName: claimName, - Namespace: namespace, - PodUIDs: podUIDs, - ResourceHandles: resourceHandles, - } - claimInfo := ClaimInfo{ - ClaimInfoState: claimInfoState, - annotations: make(map[string][]kubecontainer.Annotation), - } - return &claimInfo +// addPodReference adds a pod reference to the claim info. +func (info *ClaimInfo) addPodReference(podUID types.UID) { + info.PodUIDs.Insert(string(podUID)) } -// newClaimInfoFromResourceClaim creates a new ClaimInfo object -func newClaimInfoFromResourceClaim(resourceClaim *resourcev1alpha2.ResourceClaim) *ClaimInfo { - // Grab the allocation.resourceHandles. If there are no - // allocation.resourceHandles, create a single resourceHandle with no - // content. This will trigger processing of this claim by a single - // kubelet plugin whose name matches resourceClaim.Status.DriverName. - resourceHandles := resourceClaim.Status.Allocation.ResourceHandles - if len(resourceHandles) == 0 { - resourceHandles = make([]resourcev1alpha2.ResourceHandle, 1) - } - - return newClaimInfo( - resourceClaim.Status.DriverName, - resourceClaim.Spec.ResourceClassName, - resourceClaim.UID, - resourceClaim.Name, - resourceClaim.Namespace, - make(sets.Set[string]), - resourceHandles, - ) +// hasPodReference checks if a pod reference exists in the claim info. +func (info *ClaimInfo) hasPodReference(podUID types.UID) bool { + return info.PodUIDs.Has(string(podUID)) } -// newClaimInfoCache is a function that returns an instance of the claimInfoCache. +// deletePodReference deletes a pod reference from the claim info. +func (info *ClaimInfo) deletePodReference(podUID types.UID) { + info.PodUIDs.Delete(string(podUID)) +} + +// setPrepared marks the claim info as prepared. +func (info *ClaimInfo) setPrepared() { + info.prepared = true +} + +// isPrepared checks if claim info is prepared or not. +func (info *ClaimInfo) isPrepared() bool { + return info.prepared +} + +// newClaimInfoCache creates a new claim info cache object, pre-populated from a checkpoint (if present). 
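+// Entries are keyed by namespace/claimName, the same key format used by add,
+// contains, get and delete.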
func newClaimInfoCache(stateDir, checkpointName string) (*claimInfoCache, error) { stateImpl, err := state.NewCheckpointState(stateDir, checkpointName) if err != nil { @@ -151,46 +175,48 @@ func newClaimInfoCache(stateDir, checkpointName string) (*claimInfoCache, error) } for _, entry := range curState { - info := newClaimInfo( - entry.DriverName, - entry.ClassName, - entry.ClaimUID, - entry.ClaimName, - entry.Namespace, - entry.PodUIDs, - entry.ResourceHandles, - ) - for pluginName, cdiDevices := range entry.CDIDevices { - err := info.addCDIDevices(pluginName, cdiDevices) - if err != nil { - return nil, fmt.Errorf("failed to add CDIDevices to claimInfo %+v: %+v", info, err) - } - } - cache.add(info) + info := newClaimInfoFromState(&entry) + cache.claimInfo[info.Namespace+"/"+info.ClaimName] = info } return cache, nil } -func (cache *claimInfoCache) add(res *ClaimInfo) { +// withLock runs a function while holding the claimInfoCache lock. +func (cache *claimInfoCache) withLock(f func() error) error { cache.Lock() defer cache.Unlock() - - cache.claimInfo[res.ClaimName+res.Namespace] = res + return f() } -func (cache *claimInfoCache) get(claimName, namespace string) *ClaimInfo { +// withRLock runs a function while holding the claimInfoCache rlock. +func (cache *claimInfoCache) withRLock(f func() error) error { cache.RLock() defer cache.RUnlock() - - return cache.claimInfo[claimName+namespace] + return f() } -func (cache *claimInfoCache) delete(claimName, namespace string) { - cache.Lock() - defer cache.Unlock() +// add adds a new claim info object into the claim info cache. +func (cache *claimInfoCache) add(info *ClaimInfo) *ClaimInfo { + cache.claimInfo[info.Namespace+"/"+info.ClaimName] = info + return info +} - delete(cache.claimInfo, claimName+namespace) +// contains checks to see if a specific claim info object is already in the cache. +func (cache *claimInfoCache) contains(claimName, namespace string) bool { + _, exists := cache.claimInfo[namespace+"/"+claimName] + return exists +} + +// get gets a specific claim info object from the cache. +func (cache *claimInfoCache) get(claimName, namespace string) (*ClaimInfo, bool) { + info, exists := cache.claimInfo[namespace+"/"+claimName] + return info, exists +} + +// delete deletes a specific claim info object from the cache. +func (cache *claimInfoCache) delete(claimName, namespace string) { + delete(cache.claimInfo, namespace+"/"+claimName) } // hasPodReference checks if there is at least one claim @@ -198,26 +224,19 @@ func (cache *claimInfoCache) delete(claimName, namespace string) { // This function is used indirectly by the status manager // to check if pod can enter termination status func (cache *claimInfoCache) hasPodReference(UID types.UID) bool { - cache.RLock() - defer cache.RUnlock() - for _, claimInfo := range cache.claimInfo { - if claimInfo.PodUIDs.Has(string(UID)) { + if claimInfo.hasPodReference(UID) { return true } } - return false } +// syncToCheckpoint syncs the full claim info cache state to a checkpoint. 
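+// Callers are expected to hold the cache lock, e.g. by invoking it from within
+// a withLock callback, since syncToCheckpoint does not lock the cache itself.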
func (cache *claimInfoCache) syncToCheckpoint() error { - cache.RLock() - defer cache.RUnlock() - claimInfoStateList := make(state.ClaimInfoStateList, 0, len(cache.claimInfo)) for _, infoClaim := range cache.claimInfo { claimInfoStateList = append(claimInfoStateList, infoClaim.ClaimInfoState) } - return cache.state.Store(claimInfoStateList) } diff --git a/pkg/kubelet/cm/dra/claiminfo_test.go b/pkg/kubelet/cm/dra/claiminfo_test.go new file mode 100644 index 00000000000..58652cc605c --- /dev/null +++ b/pkg/kubelet/cm/dra/claiminfo_test.go @@ -0,0 +1,894 @@ +/* +Copyright 2024 The Kubernetes Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +*/ + +package dra + +import ( + "errors" + "fmt" + "path" + "reflect" + "sort" + "testing" + + "github.com/stretchr/testify/assert" + resourcev1alpha2 "k8s.io/api/resource/v1alpha2" + metav1 "k8s.io/apimachinery/pkg/apis/meta/v1" + "k8s.io/apimachinery/pkg/types" + "k8s.io/apimachinery/pkg/util/sets" + "k8s.io/kubernetes/pkg/kubelet/cm/dra/state" + kubecontainer "k8s.io/kubernetes/pkg/kubelet/container" +) + +// ClaimInfo test cases + +func TestNewClaimInfoFromClaim(t *testing.T) { + namespace := "test-namespace" + className := "test-class" + driverName := "test-plugin" + claimUID := types.UID("claim-uid") + claimName := "test-claim" + + for _, test := range []struct { + description string + claim *resourcev1alpha2.ResourceClaim + expectedResult *ClaimInfo + }{ + { + description: "successfully created object", + claim: &resourcev1alpha2.ResourceClaim{ + ObjectMeta: metav1.ObjectMeta{ + UID: claimUID, + Name: claimName, + Namespace: namespace, + }, + Status: resourcev1alpha2.ResourceClaimStatus{ + DriverName: driverName, + Allocation: &resourcev1alpha2.AllocationResult{ + ResourceHandles: []resourcev1alpha2.ResourceHandle{}, + }, + }, + Spec: resourcev1alpha2.ResourceClaimSpec{ + ResourceClassName: className, + }, + }, + expectedResult: &ClaimInfo{ + ClaimInfoState: state.ClaimInfoState{ + DriverName: driverName, + ClassName: className, + ClaimUID: claimUID, + ClaimName: claimName, + Namespace: claimName, + PodUIDs: sets.New[string](), + ResourceHandles: []resourcev1alpha2.ResourceHandle{ + {}, + }, + CDIDevices: make(map[string][]string), + }, + }, + }, + { + description: "successfully created object with empty allocation", + claim: &resourcev1alpha2.ResourceClaim{ + ObjectMeta: metav1.ObjectMeta{ + UID: claimUID, + Name: claimName, + Namespace: namespace, + }, + Status: resourcev1alpha2.ResourceClaimStatus{ + DriverName: driverName, + Allocation: &resourcev1alpha2.AllocationResult{}, + }, + Spec: resourcev1alpha2.ResourceClaimSpec{ + ResourceClassName: className, + }, + }, + expectedResult: &ClaimInfo{ + ClaimInfoState: state.ClaimInfoState{ + DriverName: driverName, + ClassName: className, + ClaimUID: claimUID, + ClaimName: claimName, + Namespace: claimName, + PodUIDs: sets.New[string](), + ResourceHandles: []resourcev1alpha2.ResourceHandle{ + {}, + }, + CDIDevices: make(map[string][]string), + }, + }, + }, + } { + t.Run(test.description, func(t *testing.T) { + result 
:= newClaimInfoFromClaim(test.claim) + if reflect.DeepEqual(result, test.expectedResult) { + t.Errorf("Expected %v, but got %v", test.expectedResult, result) + } + }) + } +} + +func TestNewClaimInfoFromState(t *testing.T) { + for _, test := range []struct { + description string + state *state.ClaimInfoState + expectedResult *ClaimInfo + }{ + { + description: "successfully created object", + state: &state.ClaimInfoState{ + DriverName: "test-driver", + ClassName: "test-class", + ClaimUID: "test-uid", + ClaimName: "test-claim", + Namespace: "test-namespace", + PodUIDs: sets.New[string]("test-pod-uid"), + ResourceHandles: []resourcev1alpha2.ResourceHandle{}, + CDIDevices: map[string][]string{}, + }, + }, + } { + t.Run(test.description, func(t *testing.T) { + result := newClaimInfoFromState(test.state) + if reflect.DeepEqual(result, test.expectedResult) { + t.Errorf("Expected %v, but got %v", test.expectedResult, result) + } + }) + } +} + +func TestClaimInfoSetCDIDevices(t *testing.T) { + claimUID := types.UID("claim-uid") + pluginName := "test-plugin" + device := "vendor.com/device=device1" + annotationName := fmt.Sprintf("cdi.k8s.io/%s_%s", pluginName, claimUID) + for _, test := range []struct { + description string + claimInfo *ClaimInfo + devices []string + expectedCDIDevices map[string][]string + expectedAnnotations map[string][]kubecontainer.Annotation + wantErr bool + }{ + { + description: "successfully add one device", + claimInfo: &ClaimInfo{ + ClaimInfoState: state.ClaimInfoState{ + DriverName: pluginName, + ClaimUID: claimUID, + }, + }, + devices: []string{device}, + expectedCDIDevices: map[string][]string{ + pluginName: {device}, + }, + expectedAnnotations: map[string][]kubecontainer.Annotation{ + pluginName: { + { + Name: annotationName, + Value: device, + }, + }, + }, + }, + { + description: "empty list of devices", + claimInfo: &ClaimInfo{ + ClaimInfoState: state.ClaimInfoState{ + DriverName: pluginName, + ClaimUID: claimUID, + }, + }, + devices: []string{}, + expectedCDIDevices: map[string][]string{pluginName: {}}, + expectedAnnotations: map[string][]kubecontainer.Annotation{pluginName: nil}, + }, + { + description: "incorrect device format", + claimInfo: &ClaimInfo{ + ClaimInfoState: state.ClaimInfoState{ + DriverName: pluginName, + ClaimUID: claimUID, + }, + }, + devices: []string{"incorrect"}, + wantErr: true, + }, + } { + t.Run(test.description, func(t *testing.T) { + err := test.claimInfo.setCDIDevices(pluginName, test.devices) + if test.wantErr { + assert.Error(t, err) + return + } + assert.NoError(t, err) + assert.Equal(t, test.expectedCDIDevices, test.claimInfo.CDIDevices) + assert.Equal(t, test.expectedAnnotations, test.claimInfo.annotations) + }) + } +} + +func TestClaimInfoAnnotationsAsList(t *testing.T) { + for _, test := range []struct { + description string + claimInfo *ClaimInfo + expectedResult []kubecontainer.Annotation + }{ + { + description: "empty annotations", + claimInfo: &ClaimInfo{ + annotations: map[string][]kubecontainer.Annotation{}, + }, + }, + { + description: "nil annotations", + claimInfo: &ClaimInfo{}, + }, + { + description: "valid annotations", + claimInfo: &ClaimInfo{ + annotations: map[string][]kubecontainer.Annotation{ + "test-plugin1": { + { + Name: "cdi.k8s.io/test-plugin1_claim-uid1", + Value: "vendor.com/device=device1", + }, + { + Name: "cdi.k8s.io/test-plugin1_claim-uid2", + Value: "vendor.com/device=device2", + }, + }, + "test-plugin2": { + { + Name: "cdi.k8s.io/test-plugin2_claim-uid1", + Value: "vendor.com/device=device1", + }, + { + 
Name: "cdi.k8s.io/test-plugin2_claim-uid2", + Value: "vendor.com/device=device2", + }, + }, + }, + }, + expectedResult: []kubecontainer.Annotation{ + { + Name: "cdi.k8s.io/test-plugin1_claim-uid1", + Value: "vendor.com/device=device1", + }, + { + Name: "cdi.k8s.io/test-plugin1_claim-uid2", + Value: "vendor.com/device=device2", + }, + { + Name: "cdi.k8s.io/test-plugin2_claim-uid1", + Value: "vendor.com/device=device1", + }, + { + Name: "cdi.k8s.io/test-plugin2_claim-uid2", + Value: "vendor.com/device=device2", + }, + }, + }, + } { + t.Run(test.description, func(t *testing.T) { + result := test.claimInfo.annotationsAsList() + sort.Slice(result, func(i, j int) bool { + return result[i].Name < result[j].Name + }) + assert.Equal(t, test.expectedResult, result) + }) + } +} + +func TestClaimInfoCDIdevicesAsList(t *testing.T) { + for _, test := range []struct { + description string + claimInfo *ClaimInfo + expectedResult []kubecontainer.CDIDevice + }{ + { + description: "empty CDI devices", + claimInfo: &ClaimInfo{ + ClaimInfoState: state.ClaimInfoState{ + CDIDevices: map[string][]string{}, + }, + }, + }, + { + description: "nil CDI devices", + claimInfo: &ClaimInfo{}, + }, + { + description: "valid CDI devices", + claimInfo: &ClaimInfo{ + ClaimInfoState: state.ClaimInfoState{ + CDIDevices: map[string][]string{ + "test-plugin1": { + "vendor.com/device=device1", + "vendor.com/device=device2", + }, + "test-plugin2": { + "vendor.com/device=device1", + "vendor.com/device=device2", + }, + }, + }, + }, + expectedResult: []kubecontainer.CDIDevice{ + { + Name: "vendor.com/device=device1", + }, + { + Name: "vendor.com/device=device1", + }, + { + Name: "vendor.com/device=device2", + }, + { + Name: "vendor.com/device=device2", + }, + }, + }, + } { + t.Run(test.description, func(t *testing.T) { + result := test.claimInfo.cdiDevicesAsList() + sort.Slice(result, func(i, j int) bool { + return result[i].Name < result[j].Name + }) + assert.Equal(t, test.expectedResult, result) + }) + } +} +func TestClaimInfoAddPodReference(t *testing.T) { + podUID := types.UID("pod-uid") + for _, test := range []struct { + description string + claimInfo *ClaimInfo + expectedLen int + }{ + { + description: "successfully add pod reference", + claimInfo: &ClaimInfo{ + ClaimInfoState: state.ClaimInfoState{ + PodUIDs: sets.New[string](), + }, + }, + expectedLen: 1, + }, + { + description: "duplicate pod reference", + claimInfo: &ClaimInfo{ + ClaimInfoState: state.ClaimInfoState{ + PodUIDs: sets.New[string](string(podUID)), + }, + }, + expectedLen: 1, + }, + { + description: "duplicate pod reference", + claimInfo: &ClaimInfo{ + ClaimInfoState: state.ClaimInfoState{ + PodUIDs: sets.New[string]("pod-uid1"), + }, + }, + expectedLen: 2, + }, + } { + t.Run(test.description, func(t *testing.T) { + test.claimInfo.addPodReference(podUID) + assert.True(t, test.claimInfo.hasPodReference(podUID)) + assert.Len(t, test.claimInfo.PodUIDs, test.expectedLen) + }) + } +} + +func TestClaimInfoHasPodReference(t *testing.T) { + podUID := types.UID("pod-uid") + for _, test := range []struct { + description string + claimInfo *ClaimInfo + expectedResult bool + }{ + { + description: "claim doesn't reference pod", + claimInfo: &ClaimInfo{ + ClaimInfoState: state.ClaimInfoState{ + PodUIDs: sets.New[string](), + }, + }, + }, + { + description: "claim references pod", + claimInfo: &ClaimInfo{ + ClaimInfoState: state.ClaimInfoState{ + PodUIDs: sets.New[string](string(podUID)), + }, + }, + expectedResult: true, + }, + { + description: "empty claim info", + 
claimInfo: &ClaimInfo{}, + }, + } { + t.Run(test.description, func(t *testing.T) { + assert.Equal(t, test.claimInfo.hasPodReference(podUID), test.expectedResult) + }) + } +} + +func TestClaimInfoDeletePodReference(t *testing.T) { + podUID := types.UID("pod-uid") + for _, test := range []struct { + description string + claimInfo *ClaimInfo + }{ + { + description: "claim doesn't reference pod", + claimInfo: &ClaimInfo{ + ClaimInfoState: state.ClaimInfoState{ + PodUIDs: sets.New[string](), + }, + }, + }, + { + description: "claim references pod", + claimInfo: &ClaimInfo{ + ClaimInfoState: state.ClaimInfoState{ + PodUIDs: sets.New[string](string(podUID)), + }, + }, + }, + { + description: "empty claim info", + claimInfo: &ClaimInfo{}, + }, + } { + t.Run(test.description, func(t *testing.T) { + test.claimInfo.deletePodReference(podUID) + assert.False(t, test.claimInfo.hasPodReference(podUID)) + }) + } +} + +func TestClaimInfoSetPrepared(t *testing.T) { + for _, test := range []struct { + description string + claimInfo *ClaimInfo + }{ + { + description: "claim info is not prepared", + claimInfo: &ClaimInfo{ + prepared: false, + }, + }, + { + description: "claim info is prepared", + claimInfo: &ClaimInfo{ + prepared: true, + }, + }, + { + description: "empty claim info", + claimInfo: &ClaimInfo{}, + }, + } { + t.Run(test.description, func(t *testing.T) { + test.claimInfo.setPrepared() + assert.Equal(t, test.claimInfo.isPrepared(), true) + }) + } +} + +func TestClaimInfoIsPrepared(t *testing.T) { + for _, test := range []struct { + description string + claimInfo *ClaimInfo + expectedResult bool + }{ + { + description: "claim info is not prepared", + claimInfo: &ClaimInfo{ + prepared: false, + }, + expectedResult: false, + }, + { + description: "claim info is prepared", + claimInfo: &ClaimInfo{ + prepared: true, + }, + expectedResult: true, + }, + { + description: "empty claim info", + claimInfo: &ClaimInfo{}, + expectedResult: false, + }, + } { + t.Run(test.description, func(t *testing.T) { + assert.Equal(t, test.claimInfo.isPrepared(), test.expectedResult) + }) + } +} + +// claimInfoCache test cases +func TestNewClaimInfoCache(t *testing.T) { + for _, test := range []struct { + description string + stateDir string + checkpointName string + wantErr bool + }{ + { + description: "successfully created cache", + stateDir: t.TempDir(), + checkpointName: "test-checkpoint", + }, + { + description: "empty parameters", + wantErr: true, + }, + { + description: "empty checkpoint name", + stateDir: t.TempDir(), + wantErr: true, + }, + { + description: "incorrect checkpoint name", + stateDir: path.Join(t.TempDir(), "incorrect checkpoint"), + wantErr: true, + }, + } { + t.Run(test.description, func(t *testing.T) { + result, err := newClaimInfoCache(test.stateDir, test.checkpointName) + if test.wantErr { + assert.Error(t, err) + return + } + assert.NoError(t, err) + assert.NotNil(t, result) + }) + } +} + +func TestClaimInfoCacheWithLock(t *testing.T) { + for _, test := range []struct { + description string + funcGen func(cache *claimInfoCache) func() error + wantErr bool + }{ + { + description: "cache is locked inside a function", + funcGen: func(cache *claimInfoCache) func() error { + return func() error { + if cache.RWMutex.TryLock() { + return errors.New("Lock succeeded") + } + return nil + } + }, + }, + { + description: "cache is Rlocked inside a function", + funcGen: func(cache *claimInfoCache) func() error { + return func() error { + if cache.RWMutex.TryRLock() { + return errors.New("RLock succeeded") + } 
+ return nil + } + }, + }, + { + description: "successfully called function", + funcGen: func(cache *claimInfoCache) func() error { + return func() error { + return nil + } + }, + }, + { + description: "erroring function", + funcGen: func(cache *claimInfoCache) func() error { + return func() error { + return errors.New("test error") + } + }, + wantErr: true, + }, + } { + t.Run(test.description, func(t *testing.T) { + cache, err := newClaimInfoCache(t.TempDir(), "test-checkpoint") + assert.NoError(t, err) + assert.NotNil(t, cache) + err = cache.withLock(test.funcGen(cache)) + if test.wantErr { + assert.Error(t, err) + return + } + assert.NoError(t, err) + }) + } +} + +func TestClaimInfoCacheWithRLock(t *testing.T) { + for _, test := range []struct { + description string + funcGen func(cache *claimInfoCache) func() error + wantErr bool + }{ + { + description: "RLock-ed cache allows another RLock", + funcGen: func(cache *claimInfoCache) func() error { + return func() error { + if !cache.RWMutex.TryRLock() { + return errors.New("RLock failed") + } + return nil + } + }, + }, + { + description: "cache is locked inside a function", + funcGen: func(cache *claimInfoCache) func() error { + return func() error { + if cache.RWMutex.TryLock() { + return errors.New("Lock succeeded") + } + return nil + } + }, + }, + { + description: "successfully called function", + funcGen: func(cache *claimInfoCache) func() error { + return func() error { + return nil + } + }, + }, + { + description: "erroring function", + funcGen: func(cache *claimInfoCache) func() error { + return func() error { + return errors.New("test error") + } + }, + wantErr: true, + }, + } { + t.Run(test.description, func(t *testing.T) { + cache, err := newClaimInfoCache(t.TempDir(), "test-checkpoint") + assert.NoError(t, err) + assert.NotNil(t, cache) + err = cache.withRLock(test.funcGen(cache)) + if test.wantErr { + assert.Error(t, err) + return + } + assert.NoError(t, err) + }) + } +} + +func TestClaimInfoCacheAdd(t *testing.T) { + for _, test := range []struct { + description string + claimInfo *ClaimInfo + }{ + { + description: "claimInfo successfully added", + claimInfo: &ClaimInfo{ + ClaimInfoState: state.ClaimInfoState{ + ClaimName: "test-claim", + Namespace: "test-namespace", + }, + }, + }, + } { + t.Run(test.description, func(t *testing.T) { + cache, err := newClaimInfoCache(t.TempDir(), "test-checkpoint") + assert.NoError(t, err) + assert.NotNil(t, cache) + cache.add(test.claimInfo) + assert.True(t, cache.contains(test.claimInfo.ClaimName, test.claimInfo.Namespace)) + }) + } +} + +func TestClaimInfoCacheContains(t *testing.T) { + claimName := "test-claim" + namespace := "test-namespace" + for _, test := range []struct { + description string + claimInfo *ClaimInfo + claimInfoCache *claimInfoCache + expectedResult bool + }{ + { + description: "cache hit", + claimInfoCache: &claimInfoCache{ + claimInfo: map[string]*ClaimInfo{ + namespace + "/" + claimName: { + ClaimInfoState: state.ClaimInfoState{ + ClaimName: claimName, + Namespace: namespace, + }, + }, + }, + }, + claimInfo: &ClaimInfo{ + ClaimInfoState: state.ClaimInfoState{ + ClaimName: claimName, + Namespace: namespace, + }, + }, + expectedResult: true, + }, + { + description: "cache miss", + claimInfoCache: &claimInfoCache{}, + claimInfo: &ClaimInfo{ + ClaimInfoState: state.ClaimInfoState{ + ClaimName: claimName, + Namespace: namespace, + }, + }, + }, + { + description: "cache miss: empty cache and empty claim info", + claimInfoCache: &claimInfoCache{}, + claimInfo: &ClaimInfo{ + 
ClaimInfoState: state.ClaimInfoState{}, + }, + }, + } { + t.Run(test.description, func(t *testing.T) { + assert.Equal(t, test.expectedResult, test.claimInfoCache.contains(test.claimInfo.ClaimName, test.claimInfo.Namespace)) + }) + } +} + +func TestClaimInfoCacheGet(t *testing.T) { + claimName := "test-claim" + namespace := "test-namespace" + for _, test := range []struct { + description string + claimInfoCache *claimInfoCache + expectedNil bool + expectedExists bool + }{ + { + description: "cache hit", + claimInfoCache: &claimInfoCache{ + claimInfo: map[string]*ClaimInfo{ + namespace + "/" + claimName: { + ClaimInfoState: state.ClaimInfoState{ + ClaimName: claimName, + Namespace: namespace, + }, + }, + }, + }, + expectedExists: true, + }, + { + description: "cache miss", + claimInfoCache: &claimInfoCache{}, + expectedNil: true, + }, + } { + t.Run(test.description, func(t *testing.T) { + result, exists := test.claimInfoCache.get(claimName, namespace) + assert.Equal(t, test.expectedExists, exists) + assert.Equal(t, test.expectedNil, result == nil) + }) + } +} + +func TestClaimInfoCacheDelete(t *testing.T) { + claimName := "test-claim" + namespace := "test-namespace" + for _, test := range []struct { + description string + claimInfoCache *claimInfoCache + }{ + { + description: "item in cache", + claimInfoCache: &claimInfoCache{ + claimInfo: map[string]*ClaimInfo{ + claimName + namespace: { + ClaimInfoState: state.ClaimInfoState{ + ClaimName: claimName, + Namespace: namespace, + }, + }, + }, + }, + }, + { + description: "item not in cache", + claimInfoCache: &claimInfoCache{}, + }, + } { + t.Run(test.description, func(t *testing.T) { + test.claimInfoCache.delete(claimName, namespace) + assert.False(t, test.claimInfoCache.contains(claimName, namespace)) + }) + } +} + +func TestClaimInfoCacheHasPodReference(t *testing.T) { + claimName := "test-claim" + namespace := "test-namespace" + uid := types.UID("test-uid") + for _, test := range []struct { + description string + claimInfoCache *claimInfoCache + expectedResult bool + }{ + { + description: "uid is referenced", + claimInfoCache: &claimInfoCache{ + claimInfo: map[string]*ClaimInfo{ + claimName + namespace: { + ClaimInfoState: state.ClaimInfoState{ + ClaimName: claimName, + Namespace: namespace, + PodUIDs: sets.New[string](string(uid)), + }, + }, + }, + }, + expectedResult: true, + }, + { + description: "uid is not referenced", + claimInfoCache: &claimInfoCache{}, + }, + } { + t.Run(test.description, func(t *testing.T) { + assert.Equal(t, test.expectedResult, test.claimInfoCache.hasPodReference(uid)) + }) + } +} + +func TestSyncToCheckpoint(t *testing.T) { + for _, test := range []struct { + description string + stateDir string + checkpointName string + wantErr bool + }{ + { + description: "successfully checkpointed cache", + stateDir: t.TempDir(), + checkpointName: "test-checkpoint", + }, + } { + t.Run(test.description, func(t *testing.T) { + cache, err := newClaimInfoCache(test.stateDir, test.checkpointName) + assert.NoError(t, err) + err = cache.syncToCheckpoint() + if test.wantErr { + assert.Error(t, err) + return + } + assert.NoError(t, err) + }) + } +} diff --git a/pkg/kubelet/cm/dra/manager.go b/pkg/kubelet/cm/dra/manager.go index ad9b17dfbd8..18806006031 100644 --- a/pkg/kubelet/cm/dra/manager.go +++ b/pkg/kubelet/cm/dra/manager.go @@ -19,27 +19,48 @@ package dra import ( "context" "fmt" + "time" v1 "k8s.io/api/core/v1" resourceapi "k8s.io/api/resource/v1alpha2" metav1 "k8s.io/apimachinery/pkg/apis/meta/v1" 
"k8s.io/apimachinery/pkg/types" + "k8s.io/apimachinery/pkg/util/sets" + "k8s.io/apimachinery/pkg/util/wait" clientset "k8s.io/client-go/kubernetes" "k8s.io/dynamic-resource-allocation/resourceclaim" "k8s.io/klog/v2" drapb "k8s.io/kubelet/pkg/apis/dra/v1alpha3" dra "k8s.io/kubernetes/pkg/kubelet/cm/dra/plugin" + "k8s.io/kubernetes/pkg/kubelet/config" kubecontainer "k8s.io/kubernetes/pkg/kubelet/container" ) // draManagerStateFileName is the file name where dra manager stores its state const draManagerStateFileName = "dra_manager_state" +// defaultReconcilePeriod is the default reconciliation period to keep all claim info state in sync. +const defaultReconcilePeriod = 60 * time.Second + +// ActivePodsFunc is a function that returns a list of pods to reconcile. +type ActivePodsFunc func() []*v1.Pod + // ManagerImpl is the structure in charge of managing DRA resource Plugins. type ManagerImpl struct { // cache contains cached claim info cache *claimInfoCache + // reconcilePeriod is the duration between calls to reconcileLoop. + reconcilePeriod time.Duration + + // activePods is a method for listing active pods on the node + // so all claim info state can be updated in the reconciliation loop. + activePods ActivePodsFunc + + // sourcesReady provides the readiness of kubelet configuration sources such as apiserver update readiness. + // We use it to determine when we can treat pods as inactive and react appropriately. + sourcesReady config.SourcesReady + // KubeClient reference kubeClient clientset.Interface } @@ -53,21 +74,84 @@ func NewManagerImpl(kubeClient clientset.Interface, stateFileDirectory string, n return nil, fmt.Errorf("failed to create claimInfo cache: %+v", err) } + // TODO: for now the reconcile period is not configurable. + // We should consider making it configurable in the future. + reconcilePeriod := defaultReconcilePeriod + manager := &ManagerImpl{ - cache: claimInfoCache, - kubeClient: kubeClient, + cache: claimInfoCache, + kubeClient: kubeClient, + reconcilePeriod: reconcilePeriod, + activePods: nil, + sourcesReady: nil, } return manager, nil } +// Start starts the reconcile loop of the manager. +func (m *ManagerImpl) Start(activePods ActivePodsFunc, sourcesReady config.SourcesReady) error { + m.activePods = activePods + m.sourcesReady = sourcesReady + go wait.Until(func() { m.reconcileLoop() }, m.reconcilePeriod, wait.NeverStop) + return nil +} + +// reconcileLoop ensures that any stale state in the manager's claimInfoCache gets periodically reconciled. +func (m *ManagerImpl) reconcileLoop() { + // Only once all sources are ready do we attempt to reconcile. + // This ensures that the call to m.activePods() below will succeed with + // the actual active pods list. + if m.sourcesReady == nil || !m.sourcesReady.AllReady() { + return + } + + // Get the full list of active pods. + activePods := sets.New[string]() + for _, p := range m.activePods() { + activePods.Insert(string(p.UID)) + } + + // Get the list of inactive pods still referenced by any claimInfos. 
+ type podClaims struct { + uid types.UID + namespace string + claimNames []string + } + inactivePodClaims := make(map[string]*podClaims) + m.cache.RLock() + for _, claimInfo := range m.cache.claimInfo { + for podUID := range claimInfo.PodUIDs { + if activePods.Has(podUID) { + continue + } + if inactivePodClaims[podUID] == nil { + inactivePodClaims[podUID] = &podClaims{ + uid: types.UID(podUID), + namespace: claimInfo.Namespace, + claimNames: []string{}, + } + } + inactivePodClaims[podUID].claimNames = append(inactivePodClaims[podUID].claimNames, claimInfo.ClaimName) + } + } + m.cache.RUnlock() + + // Loop through all inactive pods and call UnprepareResources on them. + for _, podClaims := range inactivePodClaims { + if err := m.unprepareResources(podClaims.uid, podClaims.namespace, podClaims.claimNames); err != nil { + klog.ErrorS(err, "Unpreparing pod resources in reconcile loop", "podUID", podClaims.uid) + } + } +} + // PrepareResources attempts to prepare all of the required resource // plugin resources for the input container, issue NodePrepareResources rpc requests // for each new resource requirement, process their responses and update the cached // containerResources on success. func (m *ManagerImpl) PrepareResources(pod *v1.Pod) error { batches := make(map[string][]*drapb.Claim) - claimInfos := make(map[types.UID]*ClaimInfo) + resourceClaims := make(map[types.UID]*resourceapi.ResourceClaim) for i := range pod.Spec.ResourceClaims { podClaim := &pod.Spec.ResourceClaims[i] klog.V(3).InfoS("Processing resource", "podClaim", podClaim.Name, "pod", pod.Name) @@ -108,48 +192,55 @@ func (m *ManagerImpl) PrepareResources(pod *v1.Pod) error { continue } - claimInfo := m.cache.get(*claimName, pod.Namespace) - if claimInfo == nil { - // claim does not exist in cache, create new claimInfo object - // to be processed later. - claimInfo = newClaimInfoFromResourceClaim(resourceClaim) - } - - // We delay checkpointing of this change until this call - // returns successfully. It is OK to do this because we - // will only return successfully from this call if the - // checkpoint has succeeded. That means if the kubelet is - // ever restarted before this checkpoint succeeds, the pod - // whose resources are being prepared would never have - // started, so it's OK (actually correct) to not include it - // in the cache. - claimInfo.addPodReference(pod.UID) - - if claimInfo.prepared { - // Already prepared this claim, no need to prepare it again - continue - } - - // Loop through all plugins and prepare for calling NodePrepareResources. - for _, resourceHandle := range claimInfo.ResourceHandles { - // If no DriverName is provided in the resourceHandle, we - // use the DriverName from the status - pluginName := resourceHandle.DriverName - if pluginName == "" { - pluginName = resourceClaim.Status.DriverName + // Atomically perform some operations on the claimInfo cache. + err = m.cache.withLock(func() error { + // Get a reference to the claim info for this claim from the cache. + // If there isn't one yet, then add it to the cache. + claimInfo, exists := m.cache.get(resourceClaim.Name, resourceClaim.Namespace) + if !exists { + claimInfo = m.cache.add(newClaimInfoFromClaim(resourceClaim)) } - claim := &drapb.Claim{ - Namespace: resourceClaim.Namespace, - Uid: string(resourceClaim.UID), - Name: resourceClaim.Name, - ResourceHandle: resourceHandle.Data, + + // Add a reference to the current pod in the claim info. 
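+ // PodUIDs is a set, so re-adding the same pod UID on a retry is harmless.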
+ claimInfo.addPodReference(pod.UID) + + // Checkpoint to ensure all claims we plan to prepare are tracked. + // If something goes wrong and the newly referenced pod gets + // deleted without a successful prepare call, we will catch + // that in the reconcile loop and take the appropriate action. + if err := m.cache.syncToCheckpoint(); err != nil { + return fmt.Errorf("failed to checkpoint claimInfo state: %w", err) } - if resourceHandle.StructuredData != nil { - claim.StructuredResourceHandle = []*resourceapi.StructuredResourceHandle{resourceHandle.StructuredData} + + // If this claim is already prepared, there is no need to prepare it again. + if claimInfo.isPrepared() { + return nil } - batches[pluginName] = append(batches[pluginName], claim) + + // This saved claim will be used to update ClaimInfo cache + // after NodePrepareResources GRPC succeeds + resourceClaims[claimInfo.ClaimUID] = resourceClaim + + // Loop through all plugins and prepare for calling NodePrepareResources. + for _, resourceHandle := range claimInfo.ResourceHandles { + claim := &drapb.Claim{ + Namespace: claimInfo.Namespace, + Uid: string(claimInfo.ClaimUID), + Name: claimInfo.ClaimName, + ResourceHandle: resourceHandle.Data, + } + if resourceHandle.StructuredData != nil { + claim.StructuredResourceHandle = []*resourceapi.StructuredResourceHandle{resourceHandle.StructuredData} + } + pluginName := resourceHandle.DriverName + batches[pluginName] = append(batches[pluginName], claim) + } + + return nil + }) + if err != nil { + return fmt.Errorf("locked cache operation: %w", err) } - claimInfos[resourceClaim.UID] = claimInfo } // Call NodePrepareResources for all claims in each batch. @@ -175,34 +266,22 @@ func (m *ManagerImpl) PrepareResources(pod *v1.Pod) error { return fmt.Errorf("NodePrepareResources failed for claim %s/%s: %s", reqClaim.Namespace, reqClaim.Name, result.Error) } - claimInfo := claimInfos[types.UID(claimUID)] + claim := resourceClaims[types.UID(claimUID)] - // Add the CDI Devices returned by NodePrepareResources to - // the claimInfo object. - err = claimInfo.addCDIDevices(pluginName, result.GetCDIDevices()) + // Add the prepared CDI devices to the claim info + err := m.cache.withLock(func() error { + info, exists := m.cache.get(claim.Name, claim.Namespace) + if !exists { + return fmt.Errorf("unable to get claim info for claim %s in namespace %s", claim.Name, claim.Namespace) + } + if err := info.setCDIDevices(pluginName, result.GetCDIDevices()); err != nil { + return fmt.Errorf("unable to add CDI devices for plugin %s of claim %s in namespace %s", pluginName, claim.Name, claim.Namespace) + } + return nil + }) if err != nil { - return fmt.Errorf("failed to add CDIDevices to claimInfo %+v: %+v", claimInfo, err) + return fmt.Errorf("locked cache operation: %w", err) } - // mark claim as (successfully) prepared by manager, so next time we don't prepare it. - claimInfo.prepared = true - - // TODO: We (re)add the claimInfo object to the cache and - // sync it to the checkpoint *after* the - // NodePrepareResources call has completed. This will cause - // issues if the kubelet gets restarted between - // NodePrepareResources and syncToCheckpoint. It will result - // in not calling NodeUnprepareResources for this claim - // because no claimInfo will be synced back to the cache - // for it after the restart. We need to resolve this issue - // before moving to beta. - m.cache.add(claimInfo) - } - - // Checkpoint to reduce redundant calls to - // NodePrepareResources after a kubelet restart. 
- err = m.cache.syncToCheckpoint() - if err != nil { - return fmt.Errorf("failed to checkpoint claimInfo state, err: %+v", err) } unfinished := len(claims) - len(response.Claims) @@ -210,11 +289,30 @@ func (m *ManagerImpl) PrepareResources(pod *v1.Pod) error { return fmt.Errorf("NodePrepareResources left out %d claims", unfinished) } } - // Checkpoint to capture all of the previous addPodReference() calls. - err := m.cache.syncToCheckpoint() + + // Atomically perform some operations on the claimInfo cache. + err := m.cache.withLock(func() error { + // Mark all pod claims as prepared. + for _, claim := range resourceClaims { + info, exists := m.cache.get(claim.Name, claim.Namespace) + if !exists { + return fmt.Errorf("unable to get claim info for claim %s in namespace %s", claim.Name, claim.Namespace) + } + info.setPrepared() + } + + // Checkpoint to ensure all prepared claims are tracked with their list + // of CDI devices attached. + if err := m.cache.syncToCheckpoint(); err != nil { + return fmt.Errorf("failed to checkpoint claimInfo state: %w", err) + } + + return nil + }) if err != nil { - return fmt.Errorf("failed to checkpoint claimInfo state, err: %+v", err) + return fmt.Errorf("locked cache operation: %w", err) } + return nil } @@ -277,21 +375,25 @@ func (m *ManagerImpl) GetResources(pod *v1.Pod, container *v1.Container) (*Conta continue } - claimInfo := m.cache.get(*claimName, pod.Namespace) - if claimInfo == nil { - return nil, fmt.Errorf("unable to get resource for namespace: %s, claim: %s", pod.Namespace, *claimName) - } - - claimInfo.RLock() - claimAnnotations := claimInfo.annotationsAsList() - klog.V(3).InfoS("Add resource annotations", "claim", *claimName, "annotations", claimAnnotations) - annotations = append(annotations, claimAnnotations...) - for _, devices := range claimInfo.CDIDevices { - for _, device := range devices { - cdiDevices = append(cdiDevices, kubecontainer.CDIDevice{Name: device}) + err := m.cache.withRLock(func() error { + claimInfo, exists := m.cache.get(*claimName, pod.Namespace) + if !exists { + return fmt.Errorf("unable to get claim info for claim %s in namespace %s", *claimName, pod.Namespace) } + + claimAnnotations := claimInfo.annotationsAsList() + klog.V(3).InfoS("Add resource annotations", "claim", *claimName, "annotations", claimAnnotations) + annotations = append(annotations, claimAnnotations...) + + devices := claimInfo.cdiDevicesAsList() + klog.V(3).InfoS("Add CDI devices", "claim", *claimName, "CDI devices", devices) + cdiDevices = append(cdiDevices, devices...) + + return nil + }) + if err != nil { + return nil, fmt.Errorf("locked cache operation: %w", err) } - claimInfo.RUnlock() } } @@ -303,60 +405,73 @@ func (m *ManagerImpl) GetResources(pod *v1.Pod, container *v1.Container) (*Conta // As such, calls to the underlying NodeUnprepareResource API are skipped for claims that have // already been successfully unprepared. func (m *ManagerImpl) UnprepareResources(pod *v1.Pod) error { - batches := make(map[string][]*drapb.Claim) - claimInfos := make(map[types.UID]*ClaimInfo) + var claimNames []string for i := range pod.Spec.ResourceClaims { claimName, _, err := resourceclaim.Name(pod, &pod.Spec.ResourceClaims[i]) if err != nil { return fmt.Errorf("unprepare resource claim: %v", err) } - // The claim name might be nil if no underlying resource claim // was generated for the referenced claim. There are valid use // cases when this might happen, so we simply skip it. 
if claimName == nil { continue } + claimNames = append(claimNames, *claimName) + } + return m.unprepareResources(pod.UID, pod.Namespace, claimNames) +} - claimInfo := m.cache.get(*claimName, pod.Namespace) +func (m *ManagerImpl) unprepareResources(podUID types.UID, namespace string, claimNames []string) error { + batches := make(map[string][]*drapb.Claim) + claimNamesMap := make(map[types.UID]string) + for _, claimName := range claimNames { + // Atomically perform some operations on the claimInfo cache. + err := m.cache.withLock(func() error { + // Get the claim info from the cache + claimInfo, exists := m.cache.get(claimName, namespace) - // Skip calling NodeUnprepareResource if claim info is not cached - if claimInfo == nil { - continue - } - - // Skip calling NodeUnprepareResource if other pods are still referencing it - if len(claimInfo.PodUIDs) > 1 { - // We delay checkpointing of this change until this call returns successfully. - // It is OK to do this because we will only return successfully from this call if - // the checkpoint has succeeded. That means if the kubelet is ever restarted - // before this checkpoint succeeds, we will simply call into this (idempotent) - // function again. - claimInfo.deletePodReference(pod.UID) - continue - } - - // Loop through all plugins and prepare for calling NodeUnprepareResources. - for _, resourceHandle := range claimInfo.ResourceHandles { - // If no DriverName is provided in the resourceHandle, we - // use the DriverName from the status - pluginName := resourceHandle.DriverName - if pluginName == "" { - pluginName = claimInfo.DriverName + // Skip calling NodeUnprepareResource if claim info is not cached + if !exists { + return nil } - claim := &drapb.Claim{ - Namespace: claimInfo.Namespace, - Uid: string(claimInfo.ClaimUID), - Name: claimInfo.ClaimName, - ResourceHandle: resourceHandle.Data, + // Skip calling NodeUnprepareResource if other pods are still referencing it + if len(claimInfo.PodUIDs) > 1 { + // We delay checkpointing of this change until + // UnprepareResources returns successfully. It is OK to do + // this because we will only return successfully from this call + // if the checkpoint has succeeded. That means if the kubelet + // is ever restarted before this checkpoint succeeds, we will + // simply call into this (idempotent) function again. + claimInfo.deletePodReference(podUID) + return nil } - if resourceHandle.StructuredData != nil { - claim.StructuredResourceHandle = []*resourceapi.StructuredResourceHandle{resourceHandle.StructuredData} + + // This claimInfo name will be used to update ClaimInfo cache + // after NodeUnprepareResources GRPC succeeds + claimNamesMap[claimInfo.ClaimUID] = claimInfo.ClaimName + + // Loop through all plugins and prepare for calling NodeUnprepareResources. + for _, resourceHandle := range claimInfo.ResourceHandles { + claim := &drapb.Claim{ + Namespace: claimInfo.Namespace, + Uid: string(claimInfo.ClaimUID), + Name: claimInfo.ClaimName, + ResourceHandle: resourceHandle.Data, + } + if resourceHandle.StructuredData != nil { + claim.StructuredResourceHandle = []*resourceapi.StructuredResourceHandle{resourceHandle.StructuredData} + } + pluginName := resourceHandle.DriverName + batches[pluginName] = append(batches[pluginName], claim) } - batches[pluginName] = append(batches[pluginName], claim) + + return nil + }) + if err != nil { + return fmt.Errorf("locked cache operation: %w", err) } - claimInfos[claimInfo.ClaimUID] = claimInfo } // Call NodeUnprepareResources for all claims in each batch. 
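The atomic sections above (and the matching ones in PrepareResources) are built on the cache-level withLock/withRLock helpers added in claiminfo.go. A minimal, self-contained sketch of that pattern, using a simplified stand-in type rather than the real claimInfoCache:

package main

import (
	"fmt"
	"sync"
)

// cache is a simplified stand-in for claimInfoCache: one RWMutex guarding a map.
type cache struct {
	sync.RWMutex
	items map[string]string
}

// withLock runs f while holding the write lock.
func (c *cache) withLock(f func() error) error {
	c.Lock()
	defer c.Unlock()
	return f()
}

// withRLock runs f while holding the read lock.
func (c *cache) withRLock(f func() error) error {
	c.RLock()
	defer c.RUnlock()
	return f()
}

func main() {
	c := &cache{items: map[string]string{}}

	// A writer batches several mutations into one critical section and can
	// return an error to abort, e.g. when checkpointing fails.
	_ = c.withLock(func() error {
		c.items["test-namespace/claim-a"] = "prepared"
		c.items["test-namespace/claim-b"] = "prepared"
		return nil
	})

	// A reader takes the shared lock for lookups.
	_ = c.withRLock(func() error {
		fmt.Println(len(c.items), "claims cached")
		return nil
	})
}

Because the closure returns an error, call sites can checkpoint or fail inside the critical section without managing explicit Unlock calls.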
@@ -382,20 +497,6 @@ func (m *ManagerImpl) UnprepareResources(pod *v1.Pod) error { if result.GetError() != "" { return fmt.Errorf("NodeUnprepareResources failed for claim %s/%s: %s", reqClaim.Namespace, reqClaim.Name, result.Error) } - - // Delete last pod UID only if unprepare succeeds. - // This ensures that the status manager doesn't enter termination status - // for the pod. This logic is implemented in - // m.PodMightNeedToUnprepareResources and claimInfo.hasPodReference. - claimInfo := claimInfos[types.UID(claimUID)] - claimInfo.deletePodReference(pod.UID) - m.cache.delete(claimInfo.ClaimName, pod.Namespace) - } - - // Checkpoint to reduce redundant calls to NodeUnprepareResources after a kubelet restart. - err = m.cache.syncToCheckpoint() - if err != nil { - return fmt.Errorf("failed to checkpoint claimInfo state, err: %+v", err) } unfinished := len(claims) - len(response.Claims) @@ -404,21 +505,35 @@ func (m *ManagerImpl) UnprepareResources(pod *v1.Pod) error { } } - // Checkpoint to capture all of the previous deletePodReference() calls. - err := m.cache.syncToCheckpoint() + // Atomically perform some operations on the claimInfo cache. + err := m.cache.withLock(func() error { + // Delete all claimInfos from the cache that have just been unprepared. + for _, claimName := range claimNamesMap { + m.cache.delete(claimName, namespace) + } + + // Atomically sync the cache back to the checkpoint. + if err := m.cache.syncToCheckpoint(); err != nil { + return fmt.Errorf("failed to checkpoint claimInfo state: %w", err) + } + return nil + }) if err != nil { - return fmt.Errorf("failed to checkpoint claimInfo state, err: %+v", err) + return fmt.Errorf("locked cache operation: %w", err) } + return nil } // PodMightNeedToUnprepareResources returns true if the pod might need to // unprepare resources func (m *ManagerImpl) PodMightNeedToUnprepareResources(UID types.UID) bool { + m.cache.Lock() + defer m.cache.Unlock() return m.cache.hasPodReference(UID) } -// GetCongtainerClaimInfos gets Container's ClaimInfo +// GetContainerClaimInfos gets Container's ClaimInfo func (m *ManagerImpl) GetContainerClaimInfos(pod *v1.Pod, container *v1.Container) ([]*ClaimInfo, error) { claimInfos := make([]*ClaimInfo, 0, len(pod.Spec.ResourceClaims)) @@ -432,11 +547,18 @@ func (m *ManagerImpl) GetContainerClaimInfos(pod *v1.Pod, container *v1.Containe if podResourceClaim.Name != claim.Name { continue } - claimInfo := m.cache.get(*claimName, pod.Namespace) - if claimInfo == nil { - return nil, fmt.Errorf("unable to get resource for namespace: %s, claim: %s", pod.Namespace, *claimName) + + err := m.cache.withRLock(func() error { + claimInfo, exists := m.cache.get(*claimName, pod.Namespace) + if !exists { + return fmt.Errorf("unable to get claim info for claim %s in namespace %s", *claimName, pod.Namespace) + } + claimInfos = append(claimInfos, claimInfo.DeepCopy()) + return nil + }) + if err != nil { + return nil, fmt.Errorf("locked cache operation: %w", err) } - claimInfos = append(claimInfos, claimInfo) } } return claimInfos, nil diff --git a/pkg/kubelet/cm/dra/manager_test.go b/pkg/kubelet/cm/dra/manager_test.go index ffda44dd828..eaf05760fac 100644 --- a/pkg/kubelet/cm/dra/manager_test.go +++ b/pkg/kubelet/cm/dra/manager_test.go @@ -22,6 +22,7 @@ import ( "net" "os" "path/filepath" + "sync" "sync/atomic" "testing" "time" @@ -31,6 +32,7 @@ import ( v1 "k8s.io/api/core/v1" resourcev1alpha2 "k8s.io/api/resource/v1alpha2" metav1 "k8s.io/apimachinery/pkg/apis/meta/v1" + "k8s.io/apimachinery/pkg/types" 
"k8s.io/apimachinery/pkg/util/sets" "k8s.io/client-go/kubernetes/fake" "k8s.io/dynamic-resource-allocation/resourceclaim" @@ -62,6 +64,18 @@ func (s *fakeDRADriverGRPCServer) NodePrepareResources(ctx context.Context, req time.Sleep(*s.timeout) } + if s.prepareResourcesResponse == nil { + deviceName := "claim-" + req.Claims[0].Uid + result := s.driverName + "/" + driverClassName + "=" + deviceName + return &drapbv1.NodePrepareResourcesResponse{ + Claims: map[string]*drapbv1.NodePrepareResourceResponse{ + req.Claims[0].Uid: { + CDIDevices: []string{result}, + }, + }, + }, nil + } + return s.prepareResourcesResponse, nil } @@ -72,6 +86,14 @@ func (s *fakeDRADriverGRPCServer) NodeUnprepareResources(ctx context.Context, re time.Sleep(*s.timeout) } + if s.unprepareResourcesResponse == nil { + return &drapbv1.NodeUnprepareResourcesResponse{ + Claims: map[string]*drapbv1.NodeUnprepareResourceResponse{ + req.Claims[0].Uid: {}, + }, + }, nil + } + return s.unprepareResourcesResponse, nil } @@ -789,7 +811,7 @@ func TestPrepareResources(t *testing.T) { DriverName: driverName, Allocation: &resourcev1alpha2.AllocationResult{ ResourceHandles: []resourcev1alpha2.ResourceHandle{ - {Data: "test-data"}, + {Data: "test-data", DriverName: driverName}, }, }, ReservedFor: []resourcev1alpha2.ResourceClaimConsumerReference{ @@ -839,16 +861,13 @@ func TestPrepareResources(t *testing.T) { }, claimInfo: &ClaimInfo{ ClaimInfoState: state.ClaimInfoState{ - DriverName: driverName, - ClassName: "test-class", - ClaimName: "test-pod-claim", - ClaimUID: "test-reserved", - Namespace: "test-namespace", - PodUIDs: sets.Set[string]{"test-reserved": sets.Empty{}}, - CDIDevices: map[string][]string{ - driverName: {fmt.Sprintf("%s/%s=some-device", driverName, driverClassName)}, - }, - ResourceHandles: []resourcev1alpha2.ResourceHandle{{Data: "test-data"}}, + DriverName: driverName, + ClassName: "test-class", + ClaimName: "test-pod-claim", + ClaimUID: "test-reserved", + Namespace: "test-namespace", + PodUIDs: sets.Set[string]{"test-reserved": sets.Empty{}}, + ResourceHandles: []resourcev1alpha2.ResourceHandle{{Data: "test-data", DriverName: driverName}}, }, annotations: make(map[string][]kubecontainer.Annotation), prepared: false, @@ -866,7 +885,7 @@ func TestPrepareResources(t *testing.T) { DriverName: driverName, Allocation: &resourcev1alpha2.AllocationResult{ ResourceHandles: []resourcev1alpha2.ResourceHandle{ - {Data: "test-data"}, + {Data: "test-data", DriverName: driverName}, }, }, ReservedFor: []resourcev1alpha2.ResourceClaimConsumerReference{ @@ -940,8 +959,8 @@ func TestPrepareResources(t *testing.T) { if err != nil { t.Fatal(err) } - claimInfo := manager.cache.get(*claimName, test.pod.Namespace) - if claimInfo == nil { + claimInfo, ok := manager.cache.get(*claimName, test.pod.Namespace) + if !ok { t.Fatalf("claimInfo not found in cache for claim %s", *claimName) } if claimInfo.DriverName != test.resourceClaim.Status.DriverName { @@ -1316,8 +1335,7 @@ func TestUnprepareResources(t *testing.T) { if err != nil { t.Fatal(err) } - claimInfo := manager.cache.get(*claimName, test.pod.Namespace) - if claimInfo != nil { + if manager.cache.contains(*claimName, test.pod.Namespace) { t.Fatalf("claimInfo still found in cache after calling UnprepareResources") } }) @@ -1337,16 +1355,20 @@ func TestPodMightNeedToUnprepareResources(t *testing.T) { cache: cache, } - podUID := sets.Set[string]{} - podUID.Insert("test-pod-uid") - manager.cache.add(&ClaimInfo{ - ClaimInfoState: state.ClaimInfoState{PodUIDs: podUID, ClaimName: "test-claim", 
Namespace: "test-namespace"}, - }) + claimName := "test-claim" + podUID := "test-pod-uid" + namespace := "test-namespace" - testClaimInfo := manager.cache.get("test-claim", "test-namespace") - testClaimInfo.addPodReference("test-pod-uid") - - manager.PodMightNeedToUnprepareResources("test-pod-uid") + claimInfo := &ClaimInfo{ + ClaimInfoState: state.ClaimInfoState{PodUIDs: sets.New(podUID), ClaimName: claimName, Namespace: namespace}, + } + manager.cache.add(claimInfo) + if !manager.cache.contains(claimName, namespace) { + t.Fatalf("failed to get claimInfo from cache for claim name %s, namespace %s: err:%v", claimName, namespace, err) + } + claimInfo.addPodReference(types.UID(podUID)) + needsUnprepare := manager.PodMightNeedToUnprepareResources(types.UID(podUID)) + assert.True(t, needsUnprepare) } func TestGetContainerClaimInfos(t *testing.T) { @@ -1428,3 +1450,116 @@ func TestGetContainerClaimInfos(t *testing.T) { }) } } + +// TestParallelPrepareUnprepareResources calls PrepareResources and UnprepareResources APIs in parallel +// to detect possible data races +func TestParallelPrepareUnprepareResources(t *testing.T) { + // Setup and register fake DRA driver + draServerInfo, err := setupFakeDRADriverGRPCServer(false, nil, nil, nil) + if err != nil { + t.Fatal(err) + } + defer draServerInfo.teardownFn() + + plg := plugin.NewRegistrationHandler(nil, getFakeNode) + if err := plg.RegisterPlugin(driverName, draServerInfo.socketName, []string{"1.27"}, nil); err != nil { + t.Fatalf("failed to register plugin %s, err: %v", driverName, err) + } + defer plg.DeRegisterPlugin(driverName) + + // Create ClaimInfo cache + cache, err := newClaimInfoCache(t.TempDir(), draManagerStateFileName) + if err != nil { + t.Errorf("failed to newClaimInfoCache, err: %+v", err) + return + } + + // Create fake Kube client and DRA manager + fakeKubeClient := fake.NewSimpleClientset() + manager := &ManagerImpl{kubeClient: fakeKubeClient, cache: cache} + + // Call PrepareResources in parallel + var wgSync, wgStart sync.WaitGroup // groups to sync goroutines + numGoroutines := 30 + wgSync.Add(numGoroutines) + wgStart.Add(1) + for i := 0; i < numGoroutines; i++ { + go func(t *testing.T, goRoutineNum int) { + defer wgSync.Done() + wgStart.Wait() // Wait to start all goroutines at the same time + + var err error + nameSpace := "test-namespace-parallel" + claimName := fmt.Sprintf("test-pod-claim-%d", goRoutineNum) + podUID := types.UID(fmt.Sprintf("test-reserved-%d", goRoutineNum)) + pod := &v1.Pod{ + ObjectMeta: metav1.ObjectMeta{ + Name: fmt.Sprintf("test-pod-%d", goRoutineNum), + Namespace: nameSpace, + UID: podUID, + }, + Spec: v1.PodSpec{ + ResourceClaims: []v1.PodResourceClaim{ + { + Name: claimName, + Source: v1.ClaimSource{ResourceClaimName: func() *string { + s := claimName + return &s + }()}, + }, + }, + Containers: []v1.Container{ + { + Resources: v1.ResourceRequirements{ + Claims: []v1.ResourceClaim{ + { + Name: claimName, + }, + }, + }, + }, + }, + }, + } + resourceClaim := &resourcev1alpha2.ResourceClaim{ + ObjectMeta: metav1.ObjectMeta{ + Name: claimName, + Namespace: nameSpace, + UID: types.UID(fmt.Sprintf("claim-%d", goRoutineNum)), + }, + Spec: resourcev1alpha2.ResourceClaimSpec{ + ResourceClassName: "test-class", + }, + Status: resourcev1alpha2.ResourceClaimStatus{ + DriverName: driverName, + Allocation: &resourcev1alpha2.AllocationResult{ + ResourceHandles: []resourcev1alpha2.ResourceHandle{ + {Data: "test-data", DriverName: driverName}, + }, + }, + ReservedFor: 
[]resourcev1alpha2.ResourceClaimConsumerReference{ + {UID: podUID}, + }, + }, + } + + if _, err = fakeKubeClient.ResourceV1alpha2().ResourceClaims(pod.Namespace).Create(context.Background(), resourceClaim, metav1.CreateOptions{}); err != nil { + t.Errorf("failed to create ResourceClaim %s: %+v", resourceClaim.Name, err) + return + } + + if err = manager.PrepareResources(pod); err != nil { + t.Errorf("pod: %s: PrepareResources failed: %+v", pod.Name, err) + return + } + + if err = manager.UnprepareResources(pod); err != nil { + t.Errorf("pod: %s: UnprepareResources failed: %+v", pod.Name, err) + return + } + + }(t, i) + } + wgStart.Done() // Start executing goroutines + wgSync.Wait() // Wait for all goroutines to finish +} diff --git a/pkg/kubelet/cm/dra/state/state_checkpoint.go b/pkg/kubelet/cm/dra/state/state_checkpoint.go index a391f0a13ca..a82f6b11bb4 100644 --- a/pkg/kubelet/cm/dra/state/state_checkpoint.go +++ b/pkg/kubelet/cm/dra/state/state_checkpoint.go @@ -36,6 +36,7 @@ type CheckpointState interface { } // ClaimInfoState is used to store claim info state in a checkpoint +// +k8s:deepcopy-gen=true type ClaimInfoState struct { // Name of the DRA driver DriverName string diff --git a/pkg/kubelet/cm/dra/state/zz_generated.deepcopy.go b/pkg/kubelet/cm/dra/state/zz_generated.deepcopy.go new file mode 100644 index 00000000000..d27ecf60883 --- /dev/null +++ b/pkg/kubelet/cm/dra/state/zz_generated.deepcopy.go @@ -0,0 +1,72 @@ +//go:build !ignore_autogenerated +// +build !ignore_autogenerated + +/* +Copyright The Kubernetes Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +*/ + +// Code generated by deepcopy-gen. DO NOT EDIT. + +package state + +import ( + v1alpha2 "k8s.io/api/resource/v1alpha2" + sets "k8s.io/apimachinery/pkg/util/sets" +) + +// DeepCopyInto is an autogenerated deepcopy function, copying the receiver, writing into out. in must be non-nil. +func (in *ClaimInfoState) DeepCopyInto(out *ClaimInfoState) { + *out = *in + if in.PodUIDs != nil { + in, out := &in.PodUIDs, &out.PodUIDs + *out = make(sets.Set[string], len(*in)) + for key, val := range *in { + (*out)[key] = val + } + } + if in.ResourceHandles != nil { + in, out := &in.ResourceHandles, &out.ResourceHandles + *out = make([]v1alpha2.ResourceHandle, len(*in)) + for i := range *in { + (*in)[i].DeepCopyInto(&(*out)[i]) + } + } + if in.CDIDevices != nil { + in, out := &in.CDIDevices, &out.CDIDevices + *out = make(map[string][]string, len(*in)) + for key, val := range *in { + var outVal []string + if val == nil { + (*out)[key] = nil + } else { + in, out := &val, &outVal + *out = make([]string, len(*in)) + copy(*out, *in) + } + (*out)[key] = outVal + } + } + return +} + +// DeepCopy is an autogenerated deepcopy function, copying the receiver, creating a new ClaimInfoState. 
+func (in *ClaimInfoState) DeepCopy() *ClaimInfoState { + if in == nil { + return nil + } + out := new(ClaimInfoState) + in.DeepCopyInto(out) + return out +} diff --git a/pkg/kubelet/cm/dra/types.go b/pkg/kubelet/cm/dra/types.go index 58c8ca0dd65..e009e952eb4 100644 --- a/pkg/kubelet/cm/dra/types.go +++ b/pkg/kubelet/cm/dra/types.go @@ -19,11 +19,16 @@ package dra import ( v1 "k8s.io/api/core/v1" "k8s.io/apimachinery/pkg/types" + "k8s.io/kubernetes/pkg/kubelet/config" kubecontainer "k8s.io/kubernetes/pkg/kubelet/container" ) // Manager manages all the DRA resource plugins running on a node. type Manager interface { + // Start starts the reconcile loop of the manager. + // This will ensure that all claims are unprepared even if pods get deleted unexpectedly. + Start(activePods ActivePodsFunc, sourcesReady config.SourcesReady) error + // PrepareResources prepares resources for a pod. // It communicates with the DRA resource plugin to prepare resources. PrepareResources(pod *v1.Pod) error diff --git a/pkg/kubelet/cm/dra/zz_generated.deepcopy.go b/pkg/kubelet/cm/dra/zz_generated.deepcopy.go new file mode 100644 index 00000000000..cc10fdaf53e --- /dev/null +++ b/pkg/kubelet/cm/dra/zz_generated.deepcopy.go @@ -0,0 +1,58 @@ +//go:build !ignore_autogenerated +// +build !ignore_autogenerated + +/* +Copyright The Kubernetes Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +*/ + +// Code generated by deepcopy-gen. DO NOT EDIT. + +package dra + +import ( + container "k8s.io/kubernetes/pkg/kubelet/container" +) + +// DeepCopyInto is an autogenerated deepcopy function, copying the receiver, writing into out. in must be non-nil. +func (in *ClaimInfo) DeepCopyInto(out *ClaimInfo) { + *out = *in + in.ClaimInfoState.DeepCopyInto(&out.ClaimInfoState) + if in.annotations != nil { + in, out := &in.annotations, &out.annotations + *out = make(map[string][]container.Annotation, len(*in)) + for key, val := range *in { + var outVal []container.Annotation + if val == nil { + (*out)[key] = nil + } else { + in, out := &val, &outVal + *out = make([]container.Annotation, len(*in)) + copy(*out, *in) + } + (*out)[key] = outVal + } + } + return +} + +// DeepCopy is an autogenerated deepcopy function, copying the receiver, creating a new ClaimInfo. +func (in *ClaimInfo) DeepCopy() *ClaimInfo { + if in == nil { + return nil + } + out := new(ClaimInfo) + in.DeepCopyInto(out) + return out +}
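
Note on the locking model: the UnprepareResources and GetContainerClaimInfos hunks above route every cache access through cache.withLock / cache.withRLock instead of locking individual ClaimInfo objects. Those helpers are defined elsewhere in the patch and are not visible in this excerpt; the following is a minimal, self-contained sketch of the pattern under the assumption that they are thin wrappers around the cache's embedded sync.RWMutex. The exampleCache type and its fields are illustrative names, not the real API.

package main

import (
	"fmt"
	"sync"
)

// exampleCache stands in for a claim-info cache: a map guarded by an
// embedded sync.RWMutex, with helpers that run a closure under the lock.
type exampleCache struct {
	sync.RWMutex
	entries map[string]string
}

// withLock runs f while holding the write lock, so a batch of deletions and
// the checkpoint write that follows them are atomic with respect to readers.
func (c *exampleCache) withLock(f func() error) error {
	c.Lock()
	defer c.Unlock()
	return f()
}

// withRLock runs f while holding the read lock; callers copy out what they
// need (cf. claimInfo.DeepCopy() above) before the lock is released.
func (c *exampleCache) withRLock(f func() error) error {
	c.RLock()
	defer c.RUnlock()
	return f()
}

func main() {
	c := &exampleCache{entries: map[string]string{"ns/claim-a": "prepared"}}

	// Shape of the UnprepareResources flow: delete entries and checkpoint
	// inside one critical section.
	if err := c.withLock(func() error {
		delete(c.entries, "ns/claim-a")
		// syncToCheckpoint() would be called here in the real manager.
		return nil
	}); err != nil {
		fmt.Println("locked cache operation:", err)
	}

	// Shape of the GetContainerClaimInfos flow: read under RLock, copy out.
	var copied string
	_ = c.withRLock(func() error {
		copied = c.entries["ns/claim-b"]
		return nil
	})
	fmt.Printf("copied value: %q\n", copied)
}

Running the closure under the lock keeps the Lock/Unlock pairing in one place and makes it harder to forget the delete-then-checkpoint ordering that the patch now enforces.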
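
Note on the new race test: TestParallelPrepareUnprepareResources uses a two-WaitGroup gate so that all goroutines issue their Prepare/Unprepare calls at roughly the same moment, which maximizes the chance that the race detector (go test -race) observes conflicting cache access. Stripped of the DRA specifics, the synchronization skeleton looks like the sketch below; the worker body is a placeholder.

package main

import (
	"fmt"
	"sync"
)

func main() {
	const numGoroutines = 30

	var wgSync, wgStart sync.WaitGroup
	wgSync.Add(numGoroutines) // completion gate: one Done() per worker
	wgStart.Add(1)            // start gate: released once, for everyone

	for i := 0; i < numGoroutines; i++ {
		go func(n int) {
			defer wgSync.Done()
			wgStart.Wait() // block until every worker has been spawned
			// PrepareResources/UnprepareResources run here in the real test.
			fmt.Println("worker", n, "running")
		}(i)
	}

	wgStart.Done() // open the start gate: all workers proceed together
	wgSync.Wait()  // wait for all workers to finish
}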
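
Note on Manager.Start: the new interface method in types.go only states the contract — a reconcile loop that eventually unprepares claims whose pods disappeared — and the loop itself is not part of this excerpt. Purely as an illustration of the shape such a loop commonly takes in kubelet managers, here is a hedged sketch built on wait.Until; the reconciler type, its fields, and the 30-second period are assumptions, not the manager's actual implementation.

package main

import (
	"fmt"
	"time"

	"k8s.io/apimachinery/pkg/util/wait"
)

// reconciler is a stand-in for the DRA manager: each pass compares the set
// of active pods against cached claims and cleans up anything orphaned.
type reconciler struct {
	activePods func() []string // stand-in for an ActivePodsFunc-style callback
}

func (r *reconciler) reconcile() {
	active := map[string]bool{}
	for _, uid := range r.activePods() {
		active[uid] = true
	}
	// In the real manager, cached claims whose pods are no longer active
	// would be unprepared here and dropped from the cache.
	fmt.Println("active pods this pass:", len(active))
}

func main() {
	r := &reconciler{activePods: func() []string { return []string{"pod-a"} }}

	stopCh := make(chan struct{})
	go wait.Until(r.reconcile, 30*time.Second, stopCh) // period is illustrative

	time.Sleep(100 * time.Millisecond)
	close(stopCh)
}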
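
Note on the generated deepcopy code: both zz_generated.deepcopy.go files exist so that ClaimInfo values handed out of the cache (claimInfo.DeepCopy() in GetContainerClaimInfos above) share no reference types with the cached originals, letting callers read them after the lock is released. The toy type below mirrors that shape — a set plus a map of slices — to show why a field-by-field deep copy is required; toyState and its hand-written DeepCopy are illustrative only, not the generated code.

package main

import (
	"fmt"

	"k8s.io/apimachinery/pkg/util/sets"
)

// toyState mirrors the reference-heavy shape of a claim-info state object.
type toyState struct {
	PodUIDs    sets.Set[string]
	CDIDevices map[string][]string
}

// DeepCopy duplicates every nested container so the copy is independent.
func (in *toyState) DeepCopy() *toyState {
	out := &toyState{
		PodUIDs:    in.PodUIDs.Clone(),
		CDIDevices: make(map[string][]string, len(in.CDIDevices)),
	}
	for k, v := range in.CDIDevices {
		out.CDIDevices[k] = append([]string(nil), v...)
	}
	return out
}

func main() {
	orig := &toyState{
		PodUIDs:    sets.New("pod-a"),
		CDIDevices: map[string][]string{"driver": {"vendor.com/class=dev0"}},
	}

	cp := orig.DeepCopy()
	cp.PodUIDs.Insert("pod-b")
	cp.CDIDevices["driver"] = append(cp.CDIDevices["driver"], "vendor.com/class=dev1")

	// The original is untouched, so a caller can mutate or inspect the copy
	// outside the cache lock without racing against the manager.
	fmt.Println(orig.PodUIDs.UnsortedList(), orig.CDIDevices["driver"])
	fmt.Println(cp.PodUIDs.UnsortedList(), cp.CDIDevices["driver"])
}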