initial copy from DRA example driver

dougbtv
2025-03-26 13:31:09 -04:00
parent 2a91646eaf
commit e420284885
7 changed files with 1007 additions and 0 deletions

View File

@@ -0,0 +1,137 @@
/*
* Copyright 2023 The Kubernetes Authors.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package main
import (
"fmt"
"os"
"sigs.k8s.io/dra-example-driver/pkg/consts"
cdiapi "tags.cncf.io/container-device-interface/pkg/cdi"
cdiparser "tags.cncf.io/container-device-interface/pkg/parser"
cdispec "tags.cncf.io/container-device-interface/specs-go"
)
const (
cdiVendor = "k8s." + consts.DriverName
cdiClass = "gpu"
cdiKind = cdiVendor + "/" + cdiClass
cdiCommonDeviceName = "common"
)
type CDIHandler struct {
cache *cdiapi.Cache
}
func NewCDIHandler(config *Config) (*CDIHandler, error) {
cache, err := cdiapi.NewCache(
cdiapi.WithSpecDirs(config.flags.cdiRoot),
)
if err != nil {
return nil, fmt.Errorf("unable to create a new CDI cache: %w", err)
}
handler := &CDIHandler{
cache: cache,
}
return handler, nil
}
func (cdi *CDIHandler) CreateCommonSpecFile() error {
spec := &cdispec.Spec{
Kind: cdiKind,
Devices: []cdispec.Device{
{
Name: cdiCommonDeviceName,
ContainerEdits: cdispec.ContainerEdits{
Env: []string{
fmt.Sprintf("KUBERNETES_NODE_NAME=%s", os.Getenv("NODE_NAME")),
fmt.Sprintf("DRA_RESOURCE_DRIVER_NAME=%s", consts.DriverName),
},
},
},
},
}
minVersion, err := cdiapi.MinimumRequiredVersion(spec)
if err != nil {
return fmt.Errorf("failed to get minimum required CDI spec version: %v", err)
}
spec.Version = minVersion
specName, err := cdiapi.GenerateNameForTransientSpec(spec, cdiCommonDeviceName)
if err != nil {
return fmt.Errorf("failed to generate Spec name: %w", err)
}
return cdi.cache.WriteSpec(spec, specName)
}
func (cdi *CDIHandler) CreateClaimSpecFile(claimUID string, devices PreparedDevices) error {
specName := cdiapi.GenerateTransientSpecName(cdiVendor, cdiClass, claimUID)
spec := &cdispec.Spec{
Kind: cdiKind,
Devices: []cdispec.Device{},
}
for _, device := range devices {
claimEdits := cdiapi.ContainerEdits{
ContainerEdits: &cdispec.ContainerEdits{
Env: []string{
fmt.Sprintf("GPU_DEVICE_%s_RESOURCE_CLAIM=%s", device.DeviceName[4:], claimUID),
},
},
}
claimEdits.Append(device.ContainerEdits)
cdiDevice := cdispec.Device{
Name: fmt.Sprintf("%s-%s", claimUID, device.DeviceName),
ContainerEdits: *claimEdits.ContainerEdits,
}
spec.Devices = append(spec.Devices, cdiDevice)
}
minVersion, err := cdiapi.MinimumRequiredVersion(spec)
if err != nil {
return fmt.Errorf("failed to get minimum required CDI spec version: %v", err)
}
spec.Version = minVersion
return cdi.cache.WriteSpec(spec, specName)
}
func (cdi *CDIHandler) DeleteClaimSpecFile(claimUID string) error {
specName := cdiapi.GenerateTransientSpecName(cdiVendor, cdiClass, claimUID)
return cdi.cache.RemoveSpec(specName)
}
func (cdi *CDIHandler) GetClaimDevices(claimUID string, devices []string) []string {
cdiDevices := []string{
cdiparser.QualifiedName(cdiVendor, cdiClass, cdiCommonDeviceName),
}
for _, device := range devices {
cdiDevice := cdiparser.QualifiedName(cdiVendor, cdiClass, fmt.Sprintf("%s-%s", claimUID, device))
cdiDevices = append(cdiDevices, cdiDevice)
}
return cdiDevices
}
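A minimal usage sketch (not part of this commit), assuming it lives next to this file in package main and uses the testify module already pulled in by the tests later in this commit: it shows the fully qualified CDI device names GetClaimDevices returns, one "common" device plus one entry per prepared device, in the vendor/class=name form produced by cdiparser.QualifiedName.

package main

import (
	"testing"

	"github.com/stretchr/testify/assert"
)

// Sketch only: name generation does not touch the CDI cache, so an empty
// handler is enough to exercise GetClaimDevices.
func TestGetClaimDevicesNameFormat(t *testing.T) {
	handler := &CDIHandler{}
	got := handler.GetClaimDevices("claim-1234", []string{"gpu-0"})
	want := []string{
		cdiKind + "=" + cdiCommonDeviceName, // the shared "common" device comes first
		cdiKind + "=claim-1234-gpu-0",       // then one entry per device in the claim
	}
	assert.Equal(t, want, got)
}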

View File

@@ -0,0 +1,53 @@
package main
import (
"encoding/json"
"k8s.io/kubernetes/pkg/kubelet/checkpointmanager/checksum"
)
type Checkpoint struct {
Checksum checksum.Checksum `json:"checksum"`
V1 *CheckpointV1 `json:"v1,omitempty"`
}
type CheckpointV1 struct {
PreparedClaims PreparedClaims `json:"preparedClaims,omitempty"`
}
func newCheckpoint() *Checkpoint {
pc := &Checkpoint{
Checksum: 0,
V1: &CheckpointV1{
PreparedClaims: make(PreparedClaims),
},
}
return pc
}
func (cp *Checkpoint) MarshalCheckpoint() ([]byte, error) {
cp.Checksum = 0
out, err := json.Marshal(*cp)
if err != nil {
return nil, err
}
cp.Checksum = checksum.New(out)
return json.Marshal(*cp)
}
func (cp *Checkpoint) UnmarshalCheckpoint(data []byte) error {
return json.Unmarshal(data, cp)
}
func (cp *Checkpoint) VerifyChecksum() error {
ck := cp.Checksum
cp.Checksum = 0
defer func() {
cp.Checksum = ck
}()
out, err := json.Marshal(*cp)
if err != nil {
return err
}
return ck.Verify(out)
}
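A round-trip sketch (not part of this commit), same package and testify assumptions: MarshalCheckpoint embeds a checksum computed over the JSON with the Checksum field zeroed, and VerifyChecksum re-derives it the same way, so a marshal, unmarshal, verify cycle should succeed.

package main

import (
	"testing"

	"github.com/stretchr/testify/assert"
)

// Sketch only: exercises the checkpointmanager checkpoint methods defined above.
func TestCheckpointRoundTrip(t *testing.T) {
	cp := newCheckpoint()
	cp.V1.PreparedClaims["claim-1234"] = PreparedDevices{}

	data, err := cp.MarshalCheckpoint()
	assert.NoError(t, err)

	restored := newCheckpoint()
	assert.NoError(t, restored.UnmarshalCheckpoint(data))
	assert.NoError(t, restored.VerifyChecksum())
}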

View File

@@ -0,0 +1,86 @@
/*
* Copyright 2023 The Kubernetes Authors.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package main
import (
"fmt"
"math/rand"
"os"
resourceapi "k8s.io/api/resource/v1beta1"
"k8s.io/apimachinery/pkg/api/resource"
"k8s.io/utils/ptr"
"github.com/google/uuid"
)
func enumerateAllPossibleDevices(numGPUs int) (AllocatableDevices, error) {
seed := os.Getenv("NODE_NAME")
uuids := generateUUIDs(seed, numGPUs)
alldevices := make(AllocatableDevices)
for i, uuid := range uuids {
device := resourceapi.Device{
Name: fmt.Sprintf("gpu-%d", i),
Basic: &resourceapi.BasicDevice{
Attributes: map[resourceapi.QualifiedName]resourceapi.DeviceAttribute{
"index": {
IntValue: ptr.To(int64(i)),
},
"uuid": {
StringValue: ptr.To(uuid),
},
"model": {
StringValue: ptr.To("LATEST-GPU-MODEL"),
},
"driverVersion": {
VersionValue: ptr.To("1.0.0"),
},
},
Capacity: map[resourceapi.QualifiedName]resourceapi.DeviceCapacity{
"memory": {
Value: resource.MustParse("80Gi"),
},
},
},
}
alldevices[device.Name] = device
}
return alldevices, nil
}
func generateUUIDs(seed string, count int) []string {
rng := rand.New(rand.NewSource(hash(seed)))
uuids := make([]string, count)
for i := 0; i < count; i++ {
charset := make([]byte, 16)
rng.Read(charset)
u, _ := uuid.FromBytes(charset)
uuids[i] = "gpu-" + u.String()
}
}
return uuids
}
func hash(s string) int64 {
h := int64(0)
for _, c := range s {
h = 31*h + int64(c)
}
return h
}
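A determinism sketch (not part of this commit), same assumptions as above: the simulated GPU UUIDs are seeded from the node name, so the same seed must reproduce the same inventory across restarts while different nodes diverge.

package main

import (
	"testing"

	"github.com/stretchr/testify/assert"
)

// Sketch only: generateUUIDs is a pure function of (seed, count).
func TestGenerateUUIDsDeterministic(t *testing.T) {
	a := generateUUIDs("node-1", 4)
	b := generateUUIDs("node-1", 4)
	c := generateUUIDs("node-2", 4)

	assert.Len(t, a, 4)
	assert.Equal(t, a, b)    // same seed, same devices
	assert.NotEqual(t, a, c) // different node name, different devices
}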

View File

@@ -0,0 +1,135 @@
/*
* Copyright 2023 The Kubernetes Authors.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package main
import (
"context"
"fmt"
metav1 "k8s.io/apimachinery/pkg/apis/meta/v1"
coreclientset "k8s.io/client-go/kubernetes"
"k8s.io/dynamic-resource-allocation/kubeletplugin"
"k8s.io/klog/v2"
drapbv1 "k8s.io/kubelet/pkg/apis/dra/v1beta1"
"sigs.k8s.io/dra-example-driver/pkg/consts"
)
var _ drapbv1.DRAPluginServer = &driver{}
type driver struct {
client coreclientset.Interface
plugin kubeletplugin.DRAPlugin
state *DeviceState
}
func NewDriver(ctx context.Context, config *Config) (*driver, error) {
driver := &driver{
client: config.coreclient,
}
state, err := NewDeviceState(config)
if err != nil {
return nil, err
}
driver.state = state
plugin, err := kubeletplugin.Start(
ctx,
[]any{driver},
kubeletplugin.KubeClient(config.coreclient),
kubeletplugin.NodeName(config.flags.nodeName),
kubeletplugin.DriverName(consts.DriverName),
kubeletplugin.RegistrarSocketPath(PluginRegistrationPath),
kubeletplugin.PluginSocketPath(DriverPluginSocketPath),
kubeletplugin.KubeletPluginSocketPath(DriverPluginSocketPath))
if err != nil {
return nil, err
}
driver.plugin = plugin
var resources kubeletplugin.Resources
for _, device := range state.allocatable {
resources.Devices = append(resources.Devices, device)
}
if err := plugin.PublishResources(ctx, resources); err != nil {
return nil, err
}
return driver, nil
}
func (d *driver) Shutdown(ctx context.Context) error {
d.plugin.Stop()
return nil
}
func (d *driver) NodePrepareResources(ctx context.Context, req *drapbv1.NodePrepareResourcesRequest) (*drapbv1.NodePrepareResourcesResponse, error) {
klog.Infof("NodePrepareResource is called: number of claims: %d", len(req.Claims))
preparedResources := &drapbv1.NodePrepareResourcesResponse{Claims: map[string]*drapbv1.NodePrepareResourceResponse{}}
for _, claim := range req.Claims {
preparedResources.Claims[claim.UID] = d.nodePrepareResource(ctx, claim)
}
return preparedResources, nil
}
func (d *driver) nodePrepareResource(ctx context.Context, claim *drapbv1.Claim) *drapbv1.NodePrepareResourceResponse {
resourceClaim, err := d.client.ResourceV1beta1().ResourceClaims(claim.Namespace).Get(
ctx,
claim.Name,
metav1.GetOptions{})
if err != nil {
return &drapbv1.NodePrepareResourceResponse{
Error: fmt.Sprintf("failed to fetch ResourceClaim %s in namespace %s", claim.Name, claim.Namespace),
}
}
prepared, err := d.state.Prepare(resourceClaim)
if err != nil {
return &drapbv1.NodePrepareResourceResponse{
Error: fmt.Sprintf("error preparing devices for claim %v: %v", claim.UID, err),
}
}
klog.Infof("Returning newly prepared devices for claim '%v': %v", claim.UID, prepared)
return &drapbv1.NodePrepareResourceResponse{Devices: prepared}
}
func (d *driver) NodeUnprepareResources(ctx context.Context, req *drapbv1.NodeUnprepareResourcesRequest) (*drapbv1.NodeUnprepareResourcesResponse, error) {
klog.Infof("NodeUnPrepareResource is called: number of claims: %d", len(req.Claims))
unpreparedResources := &drapbv1.NodeUnprepareResourcesResponse{Claims: map[string]*drapbv1.NodeUnprepareResourceResponse{}}
for _, claim := range req.Claims {
unpreparedResources.Claims[claim.UID] = d.nodeUnprepareResource(ctx, claim)
}
return unpreparedResources, nil
}
func (d *driver) nodeUnprepareResource(ctx context.Context, claim *drapbv1.Claim) *drapbv1.NodeUnprepareResourceResponse {
if err := d.state.Unprepare(claim.UID); err != nil {
return &drapbv1.NodeUnprepareResourceResponse{
Error: fmt.Sprintf("error unpreparing devices for claim %v: %v", claim.UID, err),
}
}
return &drapbv1.NodeUnprepareResourceResponse{}
}
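For orientation, a sketch (not part of this commit) of the per-claim error contract used above: failures are reported inside the response body keyed by claim UID rather than as gRPC errors, so one bad claim does not fail the whole NodePrepareResources call. The UIDs and the error message here are made up for illustration.

package main

import (
	"fmt"

	drapbv1 "k8s.io/kubelet/pkg/apis/dra/v1beta1"
)

// Sketch only: shape of a mixed success/failure response.
func exampleNodePrepareResponse() *drapbv1.NodePrepareResourcesResponse {
	return &drapbv1.NodePrepareResourcesResponse{
		Claims: map[string]*drapbv1.NodePrepareResourceResponse{
			"uid-ok": {
				Devices: []*drapbv1.Device{{DeviceName: "gpu-0"}},
			},
			"uid-bad": {
				Error: fmt.Sprintf("error preparing devices for claim %v: %v", "uid-bad", "example failure"),
			},
		},
	}
}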

View File

@@ -0,0 +1,158 @@
/*
* Copyright 2023 The Kubernetes Authors.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package main
import (
"context"
"fmt"
"os"
"os/signal"
"syscall"
"github.com/urfave/cli/v2"
coreclientset "k8s.io/client-go/kubernetes"
"k8s.io/klog/v2"
"sigs.k8s.io/dra-example-driver/pkg/consts"
"sigs.k8s.io/dra-example-driver/pkg/flags"
)
const (
PluginRegistrationPath = "/var/lib/kubelet/plugins_registry/" + consts.DriverName + ".sock"
DriverPluginPath = "/var/lib/kubelet/plugins/" + consts.DriverName
DriverPluginSocketPath = DriverPluginPath + "/plugin.sock"
DriverPluginCheckpointFile = "checkpoint.json"
)
type Flags struct {
kubeClientConfig flags.KubeClientConfig
loggingConfig *flags.LoggingConfig
nodeName string
cdiRoot string
numDevices int
}
type Config struct {
flags *Flags
coreclient coreclientset.Interface
}
func main() {
if err := newApp().Run(os.Args); err != nil {
fmt.Fprintf(os.Stderr, "Error: %v\n", err)
os.Exit(1)
}
}
func newApp() *cli.App {
flags := &Flags{
loggingConfig: flags.NewLoggingConfig(),
}
cliFlags := []cli.Flag{
&cli.StringFlag{
Name: "node-name",
Usage: "The name of the node to be worked on.",
Required: true,
Destination: &flags.nodeName,
EnvVars: []string{"NODE_NAME"},
},
&cli.StringFlag{
Name: "cdi-root",
Usage: "Absolute path to the directory where CDI files will be generated.",
Value: "/etc/cdi",
Destination: &flags.cdiRoot,
EnvVars: []string{"CDI_ROOT"},
},
&cli.IntFlag{
Name: "num-devices",
Usage: "The number of devices to be generated.",
Value: 8,
Destination: &flags.numDevices,
EnvVars: []string{"NUM_DEVICES"},
},
}
cliFlags = append(cliFlags, flags.kubeClientConfig.Flags()...)
cliFlags = append(cliFlags, flags.loggingConfig.Flags()...)
app := &cli.App{
Name: "dra-example-kubeletplugin",
Usage: "dra-example-kubeletplugin implements a DRA driver plugin.",
ArgsUsage: " ",
HideHelpCommand: true,
Flags: cliFlags,
Before: func(c *cli.Context) error {
if c.Args().Len() > 0 {
return fmt.Errorf("arguments not supported: %v", c.Args().Slice())
}
return flags.loggingConfig.Apply()
},
Action: func(c *cli.Context) error {
ctx := c.Context
clientSets, err := flags.kubeClientConfig.NewClientSets()
if err != nil {
return fmt.Errorf("create client: %v", err)
}
config := &Config{
flags: flags,
coreclient: clientSets.Core,
}
return StartPlugin(ctx, config)
},
}
return app
}
func StartPlugin(ctx context.Context, config *Config) error {
err := os.MkdirAll(DriverPluginPath, 0750)
if err != nil {
return err
}
info, err := os.Stat(config.flags.cdiRoot)
switch {
case err != nil && os.IsNotExist(err):
err := os.MkdirAll(config.flags.cdiRoot, 0750)
if err != nil {
return err
}
case err != nil:
return err
case !info.IsDir():
return fmt.Errorf("path for cdi file generation is not a directory: '%v'", err)
}
driver, err := NewDriver(ctx, config)
if err != nil {
return err
}
sigc := make(chan os.Signal, 1)
signal.Notify(sigc, syscall.SIGHUP, syscall.SIGINT, syscall.SIGTERM, syscall.SIGQUIT)
<-sigc
err = driver.Shutdown(ctx)
if err != nil {
klog.FromContext(ctx).Error(err, "Unable to cleanly shutdown driver")
}
return nil
}
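A flag-wiring sketch (not part of this commit), assuming package main, testify, and the urfave/cli/v2 module already imported above. Before and Action are overridden so no kubelet, API server, or logging setup is needed; it only shows how --node-name, --cdi-root, and --num-devices are parsed, with made-up values for illustration.

package main

import (
	"testing"

	"github.com/stretchr/testify/assert"
	"github.com/urfave/cli/v2"
)

// Sketch only: drive the CLI definition without starting the plugin.
func TestNewAppFlagWiring(t *testing.T) {
	app := newApp()
	app.Before = nil // skip arg validation and logging setup for this sketch

	var nodeName string
	var numDevices int
	app.Action = func(c *cli.Context) error {
		nodeName = c.String("node-name")
		numDevices = c.Int("num-devices")
		return nil
	}

	err := app.Run([]string{
		"dra-example-kubeletplugin",
		"--node-name=node-1",
		"--cdi-root=/tmp/cdi",
	})
	assert.NoError(t, err)
	assert.Equal(t, "node-1", nodeName)
	assert.Equal(t, 8, numDevices) // --num-devices not set, so the default applies
}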

View File

@@ -0,0 +1,382 @@
/*
* Copyright 2023 The Kubernetes Authors.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package main
import (
"fmt"
"slices"
"sync"
resourceapi "k8s.io/api/resource/v1beta1"
"k8s.io/apimachinery/pkg/runtime"
drapbv1 "k8s.io/kubelet/pkg/apis/dra/v1beta1"
"k8s.io/kubernetes/pkg/kubelet/checkpointmanager"
configapi "sigs.k8s.io/dra-example-driver/api/example.com/resource/gpu/v1alpha1"
"sigs.k8s.io/dra-example-driver/pkg/consts"
cdiapi "tags.cncf.io/container-device-interface/pkg/cdi"
cdispec "tags.cncf.io/container-device-interface/specs-go"
)
type AllocatableDevices map[string]resourceapi.Device
type PreparedDevices []*PreparedDevice
type PreparedClaims map[string]PreparedDevices
type PerDeviceCDIContainerEdits map[string]*cdiapi.ContainerEdits
type OpaqueDeviceConfig struct {
Requests []string
Config runtime.Object
}
type PreparedDevice struct {
drapbv1.Device
ContainerEdits *cdiapi.ContainerEdits
}
func (pds PreparedDevices) GetDevices() []*drapbv1.Device {
var devices []*drapbv1.Device
for _, pd := range pds {
devices = append(devices, &pd.Device)
}
return devices
}
type DeviceState struct {
sync.Mutex
cdi *CDIHandler
allocatable AllocatableDevices
checkpointManager checkpointmanager.CheckpointManager
}
func NewDeviceState(config *Config) (*DeviceState, error) {
allocatable, err := enumerateAllPossibleDevices(config.flags.numDevices)
if err != nil {
return nil, fmt.Errorf("error enumerating all possible devices: %v", err)
}
cdi, err := NewCDIHandler(config)
if err != nil {
return nil, fmt.Errorf("unable to create CDI handler: %v", err)
}
err = cdi.CreateCommonSpecFile()
if err != nil {
return nil, fmt.Errorf("unable to create CDI spec file for common edits: %v", err)
}
checkpointManager, err := checkpointmanager.NewCheckpointManager(DriverPluginPath)
if err != nil {
return nil, fmt.Errorf("unable to create checkpoint manager: %v", err)
}
state := &DeviceState{
cdi: cdi,
allocatable: allocatable,
checkpointManager: checkpointManager,
}
checkpoints, err := state.checkpointManager.ListCheckpoints()
if err != nil {
return nil, fmt.Errorf("unable to list checkpoints: %v", err)
}
for _, c := range checkpoints {
if c == DriverPluginCheckpointFile {
return state, nil
}
}
checkpoint := newCheckpoint()
if err := state.checkpointManager.CreateCheckpoint(DriverPluginCheckpointFile, checkpoint); err != nil {
return nil, fmt.Errorf("unable to sync to checkpoint: %v", err)
}
return state, nil
}
func (s *DeviceState) Prepare(claim *resourceapi.ResourceClaim) ([]*drapbv1.Device, error) {
s.Lock()
defer s.Unlock()
claimUID := string(claim.UID)
checkpoint := newCheckpoint()
if err := s.checkpointManager.GetCheckpoint(DriverPluginCheckpointFile, checkpoint); err != nil {
return nil, fmt.Errorf("unable to sync from checkpoint: %v", err)
}
preparedClaims := checkpoint.V1.PreparedClaims
if preparedClaims[claimUID] != nil {
return preparedClaims[claimUID].GetDevices(), nil
}
preparedDevices, err := s.prepareDevices(claim)
if err != nil {
return nil, fmt.Errorf("prepare failed: %v", err)
}
if err = s.cdi.CreateClaimSpecFile(claimUID, preparedDevices); err != nil {
return nil, fmt.Errorf("unable to create CDI spec file for claim: %v", err)
}
preparedClaims[claimUID] = preparedDevices
if err := s.checkpointManager.CreateCheckpoint(DriverPluginCheckpointFile, checkpoint); err != nil {
return nil, fmt.Errorf("unable to sync to checkpoint: %v", err)
}
return preparedClaims[claimUID].GetDevices(), nil
}
func (s *DeviceState) Unprepare(claimUID string) error {
s.Lock()
defer s.Unlock()
checkpoint := newCheckpoint()
if err := s.checkpointManager.GetCheckpoint(DriverPluginCheckpointFile, checkpoint); err != nil {
return fmt.Errorf("unable to sync from checkpoint: %v", err)
}
preparedClaims := checkpoint.V1.PreparedClaims
if preparedClaims[claimUID] == nil {
return nil
}
if err := s.unprepareDevices(claimUID, preparedClaims[claimUID]); err != nil {
return fmt.Errorf("unprepare failed: %v", err)
}
err := s.cdi.DeleteClaimSpecFile(claimUID)
if err != nil {
return fmt.Errorf("unable to delete CDI spec file for claim: %v", err)
}
delete(preparedClaims, claimUID)
if err := s.checkpointManager.CreateCheckpoint(DriverPluginCheckpointFile, checkpoint); err != nil {
return fmt.Errorf("unable to sync to checkpoint: %v", err)
}
return nil
}
func (s *DeviceState) prepareDevices(claim *resourceapi.ResourceClaim) (PreparedDevices, error) {
if claim.Status.Allocation == nil {
return nil, fmt.Errorf("claim not yet allocated")
}
// Retrieve the full set of device configs for the driver.
configs, err := GetOpaqueDeviceConfigs(
configapi.Decoder,
consts.DriverName,
claim.Status.Allocation.Devices.Config,
)
if err != nil {
return nil, fmt.Errorf("error getting opaque device configs: %v", err)
}
// Add the default GPU Config to the front of the config list with the
// lowest precedence. This guarantees there will be at least one config in
// the list with len(Requests) == 0 for the lookup below.
configs = slices.Insert(configs, 0, &OpaqueDeviceConfig{
Requests: []string{},
Config: configapi.DefaultGpuConfig(),
})
// Look through the configs and figure out which one will be applied to
// each device allocation result based on their order of precedence.
configResultsMap := make(map[runtime.Object][]*resourceapi.DeviceRequestAllocationResult)
for _, result := range claim.Status.Allocation.Devices.Results {
if _, exists := s.allocatable[result.Device]; !exists {
return nil, fmt.Errorf("requested GPU is not allocatable: %v", result.Device)
}
for _, c := range slices.Backward(configs) {
if len(c.Requests) == 0 || slices.Contains(c.Requests, result.Request) {
configResultsMap[c.Config] = append(configResultsMap[c.Config], &result)
break
}
}
}
// Normalize, validate, and apply all configs associated with devices that
// need to be prepared. Track container edits generated from applying the
// config to the set of device allocation results.
perDeviceCDIContainerEdits := make(PerDeviceCDIContainerEdits)
for c, results := range configResultsMap {
// Cast the opaque config to a GpuConfig
var config *configapi.GpuConfig
switch castConfig := c.(type) {
case *configapi.GpuConfig:
config = castConfig
default:
return nil, fmt.Errorf("runtime object is not a regognized configuration")
}
// Normalize the config to set any implied defaults.
if err := config.Normalize(); err != nil {
return nil, fmt.Errorf("error normalizing GPU config: %w", err)
}
// Validate the config to ensure its integrity.
if err := config.Validate(); err != nil {
return nil, fmt.Errorf("error validating GPU config: %w", err)
}
// Apply the config to the list of results associated with it.
containerEdits, err := s.applyConfig(config, results)
if err != nil {
return nil, fmt.Errorf("error applying GPU config: %w", err)
}
// Merge any new container edits with the overall per device map.
for k, v := range containerEdits {
perDeviceCDIContainerEdits[k] = v
}
}
// Walk through each config and its associated device allocation results
// and construct the list of prepared devices to return.
var preparedDevices PreparedDevices
for _, results := range configResultsMap {
for _, result := range results {
device := &PreparedDevice{
Device: drapbv1.Device{
RequestNames: []string{result.Request},
PoolName: result.Pool,
DeviceName: result.Device,
CDIDeviceIDs: s.cdi.GetClaimDevices(string(claim.UID), []string{result.Device}),
},
ContainerEdits: perDeviceCDIContainerEdits[result.Device],
}
preparedDevices = append(preparedDevices, device)
}
}
return preparedDevices, nil
}
func (s *DeviceState) unprepareDevices(claimUID string, devices PreparedDevices) error {
// In this example driver there is no per-device hardware state to tear down;
// the caller (Unprepare) removes the claim's CDI spec file and checkpoint entry.
return nil
}
// applyConfig applies a configuration to a set of device allocation results.
//
// In this example driver there is no actual configuration applied. We simply
// define a set of environment variables to be injected into the containers
// that include a given device. A real driver would likely need to do some sort
// of hardware configuration as well, based on the config passed in.
func (s *DeviceState) applyConfig(config *configapi.GpuConfig, results []*resourceapi.DeviceRequestAllocationResult) (PerDeviceCDIContainerEdits, error) {
perDeviceEdits := make(PerDeviceCDIContainerEdits)
for _, result := range results {
envs := []string{
fmt.Sprintf("GPU_DEVICE_%s=%s", result.Device[4:], result.Device),
}
if config.Sharing != nil {
envs = append(envs, fmt.Sprintf("GPU_DEVICE_%s_SHARING_STRATEGY=%s", result.Device[4:], config.Sharing.Strategy))
}
switch {
case config.Sharing.IsTimeSlicing():
tsconfig, err := config.Sharing.GetTimeSlicingConfig()
if err != nil {
return nil, fmt.Errorf("unable to get time slicing config for device %v: %w", result.Device, err)
}
envs = append(envs, fmt.Sprintf("GPU_DEVICE_%s_TIMESLICE_INTERVAL=%v", result.Device[4:], tsconfig.Interval))
case config.Sharing.IsSpacePartitioning():
spconfig, err := config.Sharing.GetSpacePartitioningConfig()
if err != nil {
return nil, fmt.Errorf("unable to get space partitioning config for device %v: %w", result.Device, err)
}
envs = append(envs, fmt.Sprintf("GPU_DEVICE_%s_PARTITION_COUNT=%v", result.Device[4:], spconfig.PartitionCount))
}
edits := &cdispec.ContainerEdits{
Env: envs,
}
perDeviceEdits[result.Device] = &cdiapi.ContainerEdits{ContainerEdits: edits}
}
return perDeviceEdits, nil
}
// GetOpaqueDeviceConfigs returns an ordered list of the configs contained in possibleConfigs for this driver.
//
// Configs can either come from the resource claim itself or from the device
// class associated with the request. Configs coming directly from the resource
// claim take precedence over configs coming from the device class. Moreover,
// configs found later in the list of configs attached to its source take
// precedence over configs found earlier in the list for that source.
//
// All of the configs relevant to the driver from the list of possibleConfigs
// will be returned in order of precedence (from lowest to highest). If no
// configs are found, nil is returned.
func GetOpaqueDeviceConfigs(
decoder runtime.Decoder,
driverName string,
possibleConfigs []resourceapi.DeviceAllocationConfiguration,
) ([]*OpaqueDeviceConfig, error) {
// Collect all configs in order of reverse precedence.
var classConfigs []resourceapi.DeviceAllocationConfiguration
var claimConfigs []resourceapi.DeviceAllocationConfiguration
var candidateConfigs []resourceapi.DeviceAllocationConfiguration
for _, config := range possibleConfigs {
switch config.Source {
case resourceapi.AllocationConfigSourceClass:
classConfigs = append(classConfigs, config)
case resourceapi.AllocationConfigSourceClaim:
claimConfigs = append(claimConfigs, config)
default:
return nil, fmt.Errorf("invalid config source: %v", config.Source)
}
}
candidateConfigs = append(candidateConfigs, classConfigs...)
candidateConfigs = append(candidateConfigs, claimConfigs...)
// Decode all configs that are relevant for the driver.
var resultConfigs []*OpaqueDeviceConfig
for _, config := range candidateConfigs {
// If this is nil, the driver doesn't support some future API extension
// and needs to be updated.
if config.DeviceConfiguration.Opaque == nil {
return nil, fmt.Errorf("only opaque parameters are supported by this driver")
}
// Configs for different drivers may have been specified because a
// single request can be satisfied by different drivers. This is not
// an error -- drivers must skip over other drivers' configs in order
// to support this.
if config.DeviceConfiguration.Opaque.Driver != driverName {
continue
}
decodedConfig, err := runtime.Decode(decoder, config.DeviceConfiguration.Opaque.Parameters.Raw)
if err != nil {
return nil, fmt.Errorf("error decoding config parameters: %w", err)
}
resultConfig := &OpaqueDeviceConfig{
Requests: config.Requests,
Config: decodedConfig,
}
resultConfigs = append(resultConfigs, resultConfig)
}
return resultConfigs, nil
}
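A filtering sketch (not part of this commit), same package assumptions: it exercises GetOpaqueDeviceConfigs on the two paths that never reach the decoder, i.e. configs addressed to another driver are skipped silently and a config without opaque parameters is rejected. The driver name "other.example.com" is made up for illustration.

package main

import (
	"testing"

	"github.com/stretchr/testify/assert"
	resourceapi "k8s.io/api/resource/v1beta1"

	configapi "sigs.k8s.io/dra-example-driver/api/example.com/resource/gpu/v1alpha1"
	"sigs.k8s.io/dra-example-driver/pkg/consts"
)

// Sketch only: no opaque parameters are decoded on either path.
func TestGetOpaqueDeviceConfigsFiltering(t *testing.T) {
	otherDriver := resourceapi.DeviceAllocationConfiguration{
		Source: resourceapi.AllocationConfigSourceClaim,
		DeviceConfiguration: resourceapi.DeviceConfiguration{
			Opaque: &resourceapi.OpaqueDeviceConfiguration{Driver: "other.example.com"},
		},
	}
	configs, err := GetOpaqueDeviceConfigs(configapi.Decoder, consts.DriverName,
		[]resourceapi.DeviceAllocationConfiguration{otherDriver})
	assert.NoError(t, err)
	assert.Nil(t, configs) // configs for other drivers are skipped, not errors

	noOpaque := resourceapi.DeviceAllocationConfiguration{
		Source: resourceapi.AllocationConfigSourceClass,
	}
	_, err = GetOpaqueDeviceConfigs(configapi.Decoder, consts.DriverName,
		[]resourceapi.DeviceAllocationConfiguration{noOpaque})
	assert.Error(t, err) // only opaque parameters are supported
}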

View File

@@ -0,0 +1,56 @@
/*
* Copyright 2025 The Kubernetes Authors.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package main
import (
"testing"
"github.com/stretchr/testify/assert"
drapbv1 "k8s.io/kubelet/pkg/apis/dra/v1beta1"
)
func TestPreparedDevicesGetDevices(t *testing.T) {
tests := map[string]struct {
preparedDevices PreparedDevices
expected []*drapbv1.Device
}{
"nil PreparedDevices": {
preparedDevices: nil,
expected: nil,
},
"several PreparedDevices": {
preparedDevices: PreparedDevices{
{Device: drapbv1.Device{DeviceName: "dev1"}},
{Device: drapbv1.Device{DeviceName: "dev2"}},
{Device: drapbv1.Device{DeviceName: "dev3"}},
},
expected: []*drapbv1.Device{
{DeviceName: "dev1"},
{DeviceName: "dev2"},
{DeviceName: "dev3"},
},
},
}
for name, test := range tests {
t.Run(name, func(t *testing.T) {
devices := test.preparedDevices.GetDevices()
assert.Equal(t, test.expected, devices)
})
}
}