mirror of
https://github.com/k8snetworkplumbingwg/multus-cni.git
synced 2025-09-18 07:28:50 +00:00
initial copy from DRA example driver
This commit is contained in:
137
cmd/dra-multus-driver/cdi.go
Normal file
137
cmd/dra-multus-driver/cdi.go
Normal file
@@ -0,0 +1,137 @@
|
||||
/*
|
||||
* Copyright 2023 The Kubernetes Authors.
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
package main
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"os"
|
||||
|
||||
"sigs.k8s.io/dra-example-driver/pkg/consts"
|
||||
|
||||
cdiapi "tags.cncf.io/container-device-interface/pkg/cdi"
|
||||
cdiparser "tags.cncf.io/container-device-interface/pkg/parser"
|
||||
cdispec "tags.cncf.io/container-device-interface/specs-go"
|
||||
)
|
||||
|
||||
const (
|
||||
cdiVendor = "k8s." + consts.DriverName
|
||||
cdiClass = "gpu"
|
||||
cdiKind = cdiVendor + "/" + cdiClass
|
||||
|
||||
cdiCommonDeviceName = "common"
|
||||
)
|
||||
|
||||
type CDIHandler struct {
|
||||
cache *cdiapi.Cache
|
||||
}
|
||||
|
||||
func NewCDIHandler(config *Config) (*CDIHandler, error) {
|
||||
cache, err := cdiapi.NewCache(
|
||||
cdiapi.WithSpecDirs(config.flags.cdiRoot),
|
||||
)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("unable to create a new CDI cache: %w", err)
|
||||
}
|
||||
handler := &CDIHandler{
|
||||
cache: cache,
|
||||
}
|
||||
|
||||
return handler, nil
|
||||
}
|
||||
|
||||
func (cdi *CDIHandler) CreateCommonSpecFile() error {
|
||||
spec := &cdispec.Spec{
|
||||
Kind: cdiKind,
|
||||
Devices: []cdispec.Device{
|
||||
{
|
||||
Name: cdiCommonDeviceName,
|
||||
ContainerEdits: cdispec.ContainerEdits{
|
||||
Env: []string{
|
||||
fmt.Sprintf("KUBERNETES_NODE_NAME=%s", os.Getenv("NODE_NAME")),
|
||||
fmt.Sprintf("DRA_RESOURCE_DRIVER_NAME=%s", consts.DriverName),
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
minVersion, err := cdiapi.MinimumRequiredVersion(spec)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to get minimum required CDI spec version: %v", err)
|
||||
}
|
||||
spec.Version = minVersion
|
||||
|
||||
specName, err := cdiapi.GenerateNameForTransientSpec(spec, cdiCommonDeviceName)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to generate Spec name: %w", err)
|
||||
}
|
||||
|
||||
return cdi.cache.WriteSpec(spec, specName)
|
||||
}
|
||||
|
||||
func (cdi *CDIHandler) CreateClaimSpecFile(claimUID string, devices PreparedDevices) error {
|
||||
specName := cdiapi.GenerateTransientSpecName(cdiVendor, cdiClass, claimUID)
|
||||
|
||||
spec := &cdispec.Spec{
|
||||
Kind: cdiKind,
|
||||
Devices: []cdispec.Device{},
|
||||
}
|
||||
|
||||
for _, device := range devices {
|
||||
claimEdits := cdiapi.ContainerEdits{
|
||||
ContainerEdits: &cdispec.ContainerEdits{
|
||||
Env: []string{
|
||||
fmt.Sprintf("GPU_DEVICE_%s_RESOURCE_CLAIM=%s", device.DeviceName[4:], claimUID),
|
||||
},
|
||||
},
|
||||
}
|
||||
claimEdits.Append(device.ContainerEdits)
|
||||
|
||||
cdiDevice := cdispec.Device{
|
||||
Name: fmt.Sprintf("%s-%s", claimUID, device.DeviceName),
|
||||
ContainerEdits: *claimEdits.ContainerEdits,
|
||||
}
|
||||
|
||||
spec.Devices = append(spec.Devices, cdiDevice)
|
||||
}
|
||||
|
||||
minVersion, err := cdiapi.MinimumRequiredVersion(spec)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to get minimum required CDI spec version: %v", err)
|
||||
}
|
||||
spec.Version = minVersion
|
||||
|
||||
return cdi.cache.WriteSpec(spec, specName)
|
||||
}
|
||||
|
||||
func (cdi *CDIHandler) DeleteClaimSpecFile(claimUID string) error {
|
||||
specName := cdiapi.GenerateTransientSpecName(cdiVendor, cdiClass, claimUID)
|
||||
return cdi.cache.RemoveSpec(specName)
|
||||
}
|
||||
|
||||
func (cdi *CDIHandler) GetClaimDevices(claimUID string, devices []string) []string {
|
||||
cdiDevices := []string{
|
||||
cdiparser.QualifiedName(cdiVendor, cdiClass, cdiCommonDeviceName),
|
||||
}
|
||||
|
||||
for _, device := range devices {
|
||||
cdiDevice := cdiparser.QualifiedName(cdiVendor, cdiClass, fmt.Sprintf("%s-%s", claimUID, device))
|
||||
cdiDevices = append(cdiDevices, cdiDevice)
|
||||
}
|
||||
|
||||
return cdiDevices
|
||||
}
|
53
cmd/dra-multus-driver/checkpoint.go
Normal file
53
cmd/dra-multus-driver/checkpoint.go
Normal file
@@ -0,0 +1,53 @@
|
||||
package main
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
|
||||
"k8s.io/kubernetes/pkg/kubelet/checkpointmanager/checksum"
|
||||
)
|
||||
|
||||
type Checkpoint struct {
|
||||
Checksum checksum.Checksum `json:"checksum"`
|
||||
V1 *CheckpointV1 `json:"v1,omitempty"`
|
||||
}
|
||||
|
||||
type CheckpointV1 struct {
|
||||
PreparedClaims PreparedClaims `json:"preparedClaims,omitempty"`
|
||||
}
|
||||
|
||||
func newCheckpoint() *Checkpoint {
|
||||
pc := &Checkpoint{
|
||||
Checksum: 0,
|
||||
V1: &CheckpointV1{
|
||||
PreparedClaims: make(PreparedClaims),
|
||||
},
|
||||
}
|
||||
return pc
|
||||
}
|
||||
|
||||
func (cp *Checkpoint) MarshalCheckpoint() ([]byte, error) {
|
||||
cp.Checksum = 0
|
||||
out, err := json.Marshal(*cp)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
cp.Checksum = checksum.New(out)
|
||||
return json.Marshal(*cp)
|
||||
}
|
||||
|
||||
func (cp *Checkpoint) UnmarshalCheckpoint(data []byte) error {
|
||||
return json.Unmarshal(data, cp)
|
||||
}
|
||||
|
||||
func (cp *Checkpoint) VerifyChecksum() error {
|
||||
ck := cp.Checksum
|
||||
cp.Checksum = 0
|
||||
defer func() {
|
||||
cp.Checksum = ck
|
||||
}()
|
||||
out, err := json.Marshal(*cp)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
return ck.Verify(out)
|
||||
}
|
86
cmd/dra-multus-driver/discovery.go
Normal file
86
cmd/dra-multus-driver/discovery.go
Normal file
@@ -0,0 +1,86 @@
|
||||
/*
|
||||
* Copyright 2023 The Kubernetes Authors.
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
package main
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"math/rand"
|
||||
"os"
|
||||
|
||||
resourceapi "k8s.io/api/resource/v1beta1"
|
||||
"k8s.io/apimachinery/pkg/api/resource"
|
||||
"k8s.io/utils/ptr"
|
||||
|
||||
"github.com/google/uuid"
|
||||
)
|
||||
|
||||
func enumerateAllPossibleDevices(numGPUs int) (AllocatableDevices, error) {
|
||||
seed := os.Getenv("NODE_NAME")
|
||||
uuids := generateUUIDs(seed, numGPUs)
|
||||
|
||||
alldevices := make(AllocatableDevices)
|
||||
for i, uuid := range uuids {
|
||||
device := resourceapi.Device{
|
||||
Name: fmt.Sprintf("gpu-%d", i),
|
||||
Basic: &resourceapi.BasicDevice{
|
||||
Attributes: map[resourceapi.QualifiedName]resourceapi.DeviceAttribute{
|
||||
"index": {
|
||||
IntValue: ptr.To(int64(i)),
|
||||
},
|
||||
"uuid": {
|
||||
StringValue: ptr.To(uuid),
|
||||
},
|
||||
"model": {
|
||||
StringValue: ptr.To("LATEST-GPU-MODEL"),
|
||||
},
|
||||
"driverVersion": {
|
||||
VersionValue: ptr.To("1.0.0"),
|
||||
},
|
||||
},
|
||||
Capacity: map[resourceapi.QualifiedName]resourceapi.DeviceCapacity{
|
||||
"memory": {
|
||||
Value: resource.MustParse("80Gi"),
|
||||
},
|
||||
},
|
||||
},
|
||||
}
|
||||
alldevices[device.Name] = device
|
||||
}
|
||||
return alldevices, nil
|
||||
}
|
||||
|
||||
func generateUUIDs(seed string, count int) []string {
|
||||
rand := rand.New(rand.NewSource(hash(seed)))
|
||||
|
||||
uuids := make([]string, count)
|
||||
for i := 0; i < count; i++ {
|
||||
charset := make([]byte, 16)
|
||||
rand.Read(charset)
|
||||
uuid, _ := uuid.FromBytes(charset)
|
||||
uuids[i] = "gpu-" + uuid.String()
|
||||
}
|
||||
|
||||
return uuids
|
||||
}
|
||||
|
||||
func hash(s string) int64 {
|
||||
h := int64(0)
|
||||
for _, c := range s {
|
||||
h = 31*h + int64(c)
|
||||
}
|
||||
return h
|
||||
}
|
135
cmd/dra-multus-driver/driver.go
Normal file
135
cmd/dra-multus-driver/driver.go
Normal file
@@ -0,0 +1,135 @@
|
||||
/*
|
||||
* Copyright 2023 The Kubernetes Authors.
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
package main
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
|
||||
metav1 "k8s.io/apimachinery/pkg/apis/meta/v1"
|
||||
coreclientset "k8s.io/client-go/kubernetes"
|
||||
"k8s.io/dynamic-resource-allocation/kubeletplugin"
|
||||
"k8s.io/klog/v2"
|
||||
|
||||
drapbv1 "k8s.io/kubelet/pkg/apis/dra/v1beta1"
|
||||
|
||||
"sigs.k8s.io/dra-example-driver/pkg/consts"
|
||||
)
|
||||
|
||||
var _ drapbv1.DRAPluginServer = &driver{}
|
||||
|
||||
type driver struct {
|
||||
client coreclientset.Interface
|
||||
plugin kubeletplugin.DRAPlugin
|
||||
state *DeviceState
|
||||
}
|
||||
|
||||
func NewDriver(ctx context.Context, config *Config) (*driver, error) {
|
||||
driver := &driver{
|
||||
client: config.coreclient,
|
||||
}
|
||||
|
||||
state, err := NewDeviceState(config)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
driver.state = state
|
||||
|
||||
plugin, err := kubeletplugin.Start(
|
||||
ctx,
|
||||
[]any{driver},
|
||||
kubeletplugin.KubeClient(config.coreclient),
|
||||
kubeletplugin.NodeName(config.flags.nodeName),
|
||||
kubeletplugin.DriverName(consts.DriverName),
|
||||
kubeletplugin.RegistrarSocketPath(PluginRegistrationPath),
|
||||
kubeletplugin.PluginSocketPath(DriverPluginSocketPath),
|
||||
kubeletplugin.KubeletPluginSocketPath(DriverPluginSocketPath))
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
driver.plugin = plugin
|
||||
|
||||
var resources kubeletplugin.Resources
|
||||
for _, device := range state.allocatable {
|
||||
resources.Devices = append(resources.Devices, device)
|
||||
}
|
||||
|
||||
if err := plugin.PublishResources(ctx, resources); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return driver, nil
|
||||
}
|
||||
|
||||
func (d *driver) Shutdown(ctx context.Context) error {
|
||||
d.plugin.Stop()
|
||||
return nil
|
||||
}
|
||||
|
||||
func (d *driver) NodePrepareResources(ctx context.Context, req *drapbv1.NodePrepareResourcesRequest) (*drapbv1.NodePrepareResourcesResponse, error) {
|
||||
klog.Infof("NodePrepareResource is called: number of claims: %d", len(req.Claims))
|
||||
preparedResources := &drapbv1.NodePrepareResourcesResponse{Claims: map[string]*drapbv1.NodePrepareResourceResponse{}}
|
||||
|
||||
for _, claim := range req.Claims {
|
||||
preparedResources.Claims[claim.UID] = d.nodePrepareResource(ctx, claim)
|
||||
}
|
||||
|
||||
return preparedResources, nil
|
||||
}
|
||||
|
||||
func (d *driver) nodePrepareResource(ctx context.Context, claim *drapbv1.Claim) *drapbv1.NodePrepareResourceResponse {
|
||||
resourceClaim, err := d.client.ResourceV1beta1().ResourceClaims(claim.Namespace).Get(
|
||||
ctx,
|
||||
claim.Name,
|
||||
metav1.GetOptions{})
|
||||
if err != nil {
|
||||
return &drapbv1.NodePrepareResourceResponse{
|
||||
Error: fmt.Sprintf("failed to fetch ResourceClaim %s in namespace %s", claim.Name, claim.Namespace),
|
||||
}
|
||||
}
|
||||
|
||||
prepared, err := d.state.Prepare(resourceClaim)
|
||||
if err != nil {
|
||||
return &drapbv1.NodePrepareResourceResponse{
|
||||
Error: fmt.Sprintf("error preparing devices for claim %v: %v", claim.UID, err),
|
||||
}
|
||||
}
|
||||
|
||||
klog.Infof("Returning newly prepared devices for claim '%v': %v", claim.UID, prepared)
|
||||
return &drapbv1.NodePrepareResourceResponse{Devices: prepared}
|
||||
}
|
||||
|
||||
func (d *driver) NodeUnprepareResources(ctx context.Context, req *drapbv1.NodeUnprepareResourcesRequest) (*drapbv1.NodeUnprepareResourcesResponse, error) {
|
||||
klog.Infof("NodeUnPrepareResource is called: number of claims: %d", len(req.Claims))
|
||||
unpreparedResources := &drapbv1.NodeUnprepareResourcesResponse{Claims: map[string]*drapbv1.NodeUnprepareResourceResponse{}}
|
||||
|
||||
for _, claim := range req.Claims {
|
||||
unpreparedResources.Claims[claim.UID] = d.nodeUnprepareResource(ctx, claim)
|
||||
}
|
||||
|
||||
return unpreparedResources, nil
|
||||
}
|
||||
|
||||
func (d *driver) nodeUnprepareResource(ctx context.Context, claim *drapbv1.Claim) *drapbv1.NodeUnprepareResourceResponse {
|
||||
if err := d.state.Unprepare(claim.UID); err != nil {
|
||||
return &drapbv1.NodeUnprepareResourceResponse{
|
||||
Error: fmt.Sprintf("error unpreparing devices for claim %v: %v", claim.UID, err),
|
||||
}
|
||||
}
|
||||
|
||||
return &drapbv1.NodeUnprepareResourceResponse{}
|
||||
}
|
158
cmd/dra-multus-driver/main.go
Normal file
158
cmd/dra-multus-driver/main.go
Normal file
@@ -0,0 +1,158 @@
|
||||
/*
|
||||
* Copyright 2023 The Kubernetes Authors.
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
package main
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"os"
|
||||
"os/signal"
|
||||
"syscall"
|
||||
|
||||
"github.com/urfave/cli/v2"
|
||||
|
||||
coreclientset "k8s.io/client-go/kubernetes"
|
||||
"k8s.io/klog/v2"
|
||||
|
||||
"sigs.k8s.io/dra-example-driver/pkg/consts"
|
||||
"sigs.k8s.io/dra-example-driver/pkg/flags"
|
||||
)
|
||||
|
||||
const (
|
||||
PluginRegistrationPath = "/var/lib/kubelet/plugins_registry/" + consts.DriverName + ".sock"
|
||||
DriverPluginPath = "/var/lib/kubelet/plugins/" + consts.DriverName
|
||||
DriverPluginSocketPath = DriverPluginPath + "/plugin.sock"
|
||||
DriverPluginCheckpointFile = "checkpoint.json"
|
||||
)
|
||||
|
||||
type Flags struct {
|
||||
kubeClientConfig flags.KubeClientConfig
|
||||
loggingConfig *flags.LoggingConfig
|
||||
|
||||
nodeName string
|
||||
cdiRoot string
|
||||
numDevices int
|
||||
}
|
||||
|
||||
type Config struct {
|
||||
flags *Flags
|
||||
coreclient coreclientset.Interface
|
||||
}
|
||||
|
||||
func main() {
|
||||
if err := newApp().Run(os.Args); err != nil {
|
||||
fmt.Fprintf(os.Stderr, "Error: %v\n", err)
|
||||
os.Exit(1)
|
||||
}
|
||||
}
|
||||
|
||||
func newApp() *cli.App {
|
||||
flags := &Flags{
|
||||
loggingConfig: flags.NewLoggingConfig(),
|
||||
}
|
||||
cliFlags := []cli.Flag{
|
||||
&cli.StringFlag{
|
||||
Name: "node-name",
|
||||
Usage: "The name of the node to be worked on.",
|
||||
Required: true,
|
||||
Destination: &flags.nodeName,
|
||||
EnvVars: []string{"NODE_NAME"},
|
||||
},
|
||||
&cli.StringFlag{
|
||||
Name: "cdi-root",
|
||||
Usage: "Absolute path to the directory where CDI files will be generated.",
|
||||
Value: "/etc/cdi",
|
||||
Destination: &flags.cdiRoot,
|
||||
EnvVars: []string{"CDI_ROOT"},
|
||||
},
|
||||
&cli.IntFlag{
|
||||
Name: "num-devices",
|
||||
Usage: "The number of devices to be generated.",
|
||||
Value: 8,
|
||||
Destination: &flags.numDevices,
|
||||
EnvVars: []string{"NUM_DEVICES"},
|
||||
},
|
||||
}
|
||||
cliFlags = append(cliFlags, flags.kubeClientConfig.Flags()...)
|
||||
cliFlags = append(cliFlags, flags.loggingConfig.Flags()...)
|
||||
|
||||
app := &cli.App{
|
||||
Name: "dra-example-kubeletplugin",
|
||||
Usage: "dra-example-kubeletplugin implements a DRA driver plugin.",
|
||||
ArgsUsage: " ",
|
||||
HideHelpCommand: true,
|
||||
Flags: cliFlags,
|
||||
Before: func(c *cli.Context) error {
|
||||
if c.Args().Len() > 0 {
|
||||
return fmt.Errorf("arguments not supported: %v", c.Args().Slice())
|
||||
}
|
||||
return flags.loggingConfig.Apply()
|
||||
},
|
||||
Action: func(c *cli.Context) error {
|
||||
ctx := c.Context
|
||||
clientSets, err := flags.kubeClientConfig.NewClientSets()
|
||||
if err != nil {
|
||||
return fmt.Errorf("create client: %v", err)
|
||||
}
|
||||
|
||||
config := &Config{
|
||||
flags: flags,
|
||||
coreclient: clientSets.Core,
|
||||
}
|
||||
|
||||
return StartPlugin(ctx, config)
|
||||
},
|
||||
}
|
||||
|
||||
return app
|
||||
}
|
||||
|
||||
func StartPlugin(ctx context.Context, config *Config) error {
|
||||
err := os.MkdirAll(DriverPluginPath, 0750)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
info, err := os.Stat(config.flags.cdiRoot)
|
||||
switch {
|
||||
case err != nil && os.IsNotExist(err):
|
||||
err := os.MkdirAll(config.flags.cdiRoot, 0750)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
case err != nil:
|
||||
return err
|
||||
case !info.IsDir():
|
||||
return fmt.Errorf("path for cdi file generation is not a directory: '%v'", err)
|
||||
}
|
||||
|
||||
driver, err := NewDriver(ctx, config)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
sigc := make(chan os.Signal, 1)
|
||||
signal.Notify(sigc, syscall.SIGHUP, syscall.SIGINT, syscall.SIGTERM, syscall.SIGQUIT)
|
||||
<-sigc
|
||||
|
||||
err = driver.Shutdown(ctx)
|
||||
if err != nil {
|
||||
klog.FromContext(ctx).Error(err, "Unable to cleanly shutdown driver")
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
382
cmd/dra-multus-driver/state.go
Normal file
382
cmd/dra-multus-driver/state.go
Normal file
@@ -0,0 +1,382 @@
|
||||
/*
|
||||
* Copyright 2023 The Kubernetes Authors.
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
package main
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"slices"
|
||||
"sync"
|
||||
|
||||
resourceapi "k8s.io/api/resource/v1beta1"
|
||||
"k8s.io/apimachinery/pkg/runtime"
|
||||
drapbv1 "k8s.io/kubelet/pkg/apis/dra/v1beta1"
|
||||
"k8s.io/kubernetes/pkg/kubelet/checkpointmanager"
|
||||
|
||||
configapi "sigs.k8s.io/dra-example-driver/api/example.com/resource/gpu/v1alpha1"
|
||||
"sigs.k8s.io/dra-example-driver/pkg/consts"
|
||||
|
||||
cdiapi "tags.cncf.io/container-device-interface/pkg/cdi"
|
||||
cdispec "tags.cncf.io/container-device-interface/specs-go"
|
||||
)
|
||||
|
||||
type AllocatableDevices map[string]resourceapi.Device
|
||||
type PreparedDevices []*PreparedDevice
|
||||
type PreparedClaims map[string]PreparedDevices
|
||||
type PerDeviceCDIContainerEdits map[string]*cdiapi.ContainerEdits
|
||||
|
||||
type OpaqueDeviceConfig struct {
|
||||
Requests []string
|
||||
Config runtime.Object
|
||||
}
|
||||
|
||||
type PreparedDevice struct {
|
||||
drapbv1.Device
|
||||
ContainerEdits *cdiapi.ContainerEdits
|
||||
}
|
||||
|
||||
func (pds PreparedDevices) GetDevices() []*drapbv1.Device {
|
||||
var devices []*drapbv1.Device
|
||||
for _, pd := range pds {
|
||||
devices = append(devices, &pd.Device)
|
||||
}
|
||||
return devices
|
||||
}
|
||||
|
||||
type DeviceState struct {
|
||||
sync.Mutex
|
||||
cdi *CDIHandler
|
||||
allocatable AllocatableDevices
|
||||
checkpointManager checkpointmanager.CheckpointManager
|
||||
}
|
||||
|
||||
func NewDeviceState(config *Config) (*DeviceState, error) {
|
||||
allocatable, err := enumerateAllPossibleDevices(config.flags.numDevices)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("error enumerating all possible devices: %v", err)
|
||||
}
|
||||
|
||||
cdi, err := NewCDIHandler(config)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("unable to create CDI handler: %v", err)
|
||||
}
|
||||
|
||||
err = cdi.CreateCommonSpecFile()
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("unable to create CDI spec file for common edits: %v", err)
|
||||
}
|
||||
|
||||
checkpointManager, err := checkpointmanager.NewCheckpointManager(DriverPluginPath)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("unable to create checkpoint manager: %v", err)
|
||||
}
|
||||
|
||||
state := &DeviceState{
|
||||
cdi: cdi,
|
||||
allocatable: allocatable,
|
||||
checkpointManager: checkpointManager,
|
||||
}
|
||||
|
||||
checkpoints, err := state.checkpointManager.ListCheckpoints()
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("unable to list checkpoints: %v", err)
|
||||
}
|
||||
|
||||
for _, c := range checkpoints {
|
||||
if c == DriverPluginCheckpointFile {
|
||||
return state, nil
|
||||
}
|
||||
}
|
||||
|
||||
checkpoint := newCheckpoint()
|
||||
if err := state.checkpointManager.CreateCheckpoint(DriverPluginCheckpointFile, checkpoint); err != nil {
|
||||
return nil, fmt.Errorf("unable to sync to checkpoint: %v", err)
|
||||
}
|
||||
|
||||
return state, nil
|
||||
}
|
||||
|
||||
func (s *DeviceState) Prepare(claim *resourceapi.ResourceClaim) ([]*drapbv1.Device, error) {
|
||||
s.Lock()
|
||||
defer s.Unlock()
|
||||
|
||||
claimUID := string(claim.UID)
|
||||
|
||||
checkpoint := newCheckpoint()
|
||||
if err := s.checkpointManager.GetCheckpoint(DriverPluginCheckpointFile, checkpoint); err != nil {
|
||||
return nil, fmt.Errorf("unable to sync from checkpoint: %v", err)
|
||||
}
|
||||
preparedClaims := checkpoint.V1.PreparedClaims
|
||||
|
||||
if preparedClaims[claimUID] != nil {
|
||||
return preparedClaims[claimUID].GetDevices(), nil
|
||||
}
|
||||
|
||||
preparedDevices, err := s.prepareDevices(claim)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("prepare failed: %v", err)
|
||||
}
|
||||
|
||||
if err = s.cdi.CreateClaimSpecFile(claimUID, preparedDevices); err != nil {
|
||||
return nil, fmt.Errorf("unable to create CDI spec file for claim: %v", err)
|
||||
}
|
||||
|
||||
preparedClaims[claimUID] = preparedDevices
|
||||
if err := s.checkpointManager.CreateCheckpoint(DriverPluginCheckpointFile, checkpoint); err != nil {
|
||||
return nil, fmt.Errorf("unable to sync to checkpoint: %v", err)
|
||||
}
|
||||
|
||||
return preparedClaims[claimUID].GetDevices(), nil
|
||||
}
|
||||
|
||||
func (s *DeviceState) Unprepare(claimUID string) error {
|
||||
s.Lock()
|
||||
defer s.Unlock()
|
||||
|
||||
checkpoint := newCheckpoint()
|
||||
if err := s.checkpointManager.GetCheckpoint(DriverPluginCheckpointFile, checkpoint); err != nil {
|
||||
return fmt.Errorf("unable to sync from checkpoint: %v", err)
|
||||
}
|
||||
preparedClaims := checkpoint.V1.PreparedClaims
|
||||
|
||||
if preparedClaims[claimUID] == nil {
|
||||
return nil
|
||||
}
|
||||
|
||||
if err := s.unprepareDevices(claimUID, preparedClaims[claimUID]); err != nil {
|
||||
return fmt.Errorf("unprepare failed: %v", err)
|
||||
}
|
||||
|
||||
err := s.cdi.DeleteClaimSpecFile(claimUID)
|
||||
if err != nil {
|
||||
return fmt.Errorf("unable to delete CDI spec file for claim: %v", err)
|
||||
}
|
||||
|
||||
delete(preparedClaims, claimUID)
|
||||
if err := s.checkpointManager.CreateCheckpoint(DriverPluginCheckpointFile, checkpoint); err != nil {
|
||||
return fmt.Errorf("unable to sync to checkpoint: %v", err)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (s *DeviceState) prepareDevices(claim *resourceapi.ResourceClaim) (PreparedDevices, error) {
|
||||
if claim.Status.Allocation == nil {
|
||||
return nil, fmt.Errorf("claim not yet allocated")
|
||||
}
|
||||
|
||||
// Retrieve the full set of device configs for the driver.
|
||||
configs, err := GetOpaqueDeviceConfigs(
|
||||
configapi.Decoder,
|
||||
consts.DriverName,
|
||||
claim.Status.Allocation.Devices.Config,
|
||||
)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("error getting opaque device configs: %v", err)
|
||||
}
|
||||
|
||||
// Add the default GPU Config to the front of the config list with the
|
||||
// lowest precedence. This guarantees there will be at least one config in
|
||||
// the list with len(Requests) == 0 for the lookup below.
|
||||
configs = slices.Insert(configs, 0, &OpaqueDeviceConfig{
|
||||
Requests: []string{},
|
||||
Config: configapi.DefaultGpuConfig(),
|
||||
})
|
||||
|
||||
// Look through the configs and figure out which one will be applied to
|
||||
// each device allocation result based on their order of precedence.
|
||||
configResultsMap := make(map[runtime.Object][]*resourceapi.DeviceRequestAllocationResult)
|
||||
for _, result := range claim.Status.Allocation.Devices.Results {
|
||||
if _, exists := s.allocatable[result.Device]; !exists {
|
||||
return nil, fmt.Errorf("requested GPU is not allocatable: %v", result.Device)
|
||||
}
|
||||
for _, c := range slices.Backward(configs) {
|
||||
if len(c.Requests) == 0 || slices.Contains(c.Requests, result.Request) {
|
||||
configResultsMap[c.Config] = append(configResultsMap[c.Config], &result)
|
||||
break
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Normalize, validate, and apply all configs associated with devices that
|
||||
// need to be prepared. Track container edits generated from applying the
|
||||
// config to the set of device allocation results.
|
||||
perDeviceCDIContainerEdits := make(PerDeviceCDIContainerEdits)
|
||||
for c, results := range configResultsMap {
|
||||
// Cast the opaque config to a GpuConfig
|
||||
var config *configapi.GpuConfig
|
||||
switch castConfig := c.(type) {
|
||||
case *configapi.GpuConfig:
|
||||
config = castConfig
|
||||
default:
|
||||
return nil, fmt.Errorf("runtime object is not a regognized configuration")
|
||||
}
|
||||
|
||||
// Normalize the config to set any implied defaults.
|
||||
if err := config.Normalize(); err != nil {
|
||||
return nil, fmt.Errorf("error normalizing GPU config: %w", err)
|
||||
}
|
||||
|
||||
// Validate the config to ensure its integrity.
|
||||
if err := config.Validate(); err != nil {
|
||||
return nil, fmt.Errorf("error validating GPU config: %w", err)
|
||||
}
|
||||
|
||||
// Apply the config to the list of results associated with it.
|
||||
containerEdits, err := s.applyConfig(config, results)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("error applying GPU config: %w", err)
|
||||
}
|
||||
|
||||
// Merge any new container edits with the overall per device map.
|
||||
for k, v := range containerEdits {
|
||||
perDeviceCDIContainerEdits[k] = v
|
||||
}
|
||||
}
|
||||
|
||||
// Walk through each config and its associated device allocation results
|
||||
// and construct the list of prepared devices to return.
|
||||
var preparedDevices PreparedDevices
|
||||
for _, results := range configResultsMap {
|
||||
for _, result := range results {
|
||||
device := &PreparedDevice{
|
||||
Device: drapbv1.Device{
|
||||
RequestNames: []string{result.Request},
|
||||
PoolName: result.Pool,
|
||||
DeviceName: result.Device,
|
||||
CDIDeviceIDs: s.cdi.GetClaimDevices(string(claim.UID), []string{result.Device}),
|
||||
},
|
||||
ContainerEdits: perDeviceCDIContainerEdits[result.Device],
|
||||
}
|
||||
preparedDevices = append(preparedDevices, device)
|
||||
}
|
||||
}
|
||||
|
||||
return preparedDevices, nil
|
||||
}
|
||||
|
||||
func (s *DeviceState) unprepareDevices(claimUID string, devices PreparedDevices) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
// applyConfig applies a configuration to a set of device allocation results.
|
||||
//
|
||||
// In this example driver there is no actual configuration applied. We simply
|
||||
// define a set of environment variables to be injected into the containers
|
||||
// that include a given device. A real driver would likely need to do some sort
|
||||
// of hardware configuration as well, based on the config passed in.
|
||||
func (s *DeviceState) applyConfig(config *configapi.GpuConfig, results []*resourceapi.DeviceRequestAllocationResult) (PerDeviceCDIContainerEdits, error) {
|
||||
perDeviceEdits := make(PerDeviceCDIContainerEdits)
|
||||
|
||||
for _, result := range results {
|
||||
envs := []string{
|
||||
fmt.Sprintf("GPU_DEVICE_%s=%s", result.Device[4:], result.Device),
|
||||
}
|
||||
|
||||
if config.Sharing != nil {
|
||||
envs = append(envs, fmt.Sprintf("GPU_DEVICE_%s_SHARING_STRATEGY=%s", result.Device[4:], config.Sharing.Strategy))
|
||||
}
|
||||
|
||||
switch {
|
||||
case config.Sharing.IsTimeSlicing():
|
||||
tsconfig, err := config.Sharing.GetTimeSlicingConfig()
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("unable to get time slicing config for device %v: %w", result.Device, err)
|
||||
}
|
||||
envs = append(envs, fmt.Sprintf("GPU_DEVICE_%s_TIMESLICE_INTERVAL=%v", result.Device[4:], tsconfig.Interval))
|
||||
case config.Sharing.IsSpacePartitioning():
|
||||
spconfig, err := config.Sharing.GetSpacePartitioningConfig()
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("unable to get space partitioning config for device %v: %w", result.Device, err)
|
||||
}
|
||||
envs = append(envs, fmt.Sprintf("GPU_DEVICE_%s_PARTITION_COUNT=%v", result.Device[4:], spconfig.PartitionCount))
|
||||
}
|
||||
|
||||
edits := &cdispec.ContainerEdits{
|
||||
Env: envs,
|
||||
}
|
||||
|
||||
perDeviceEdits[result.Device] = &cdiapi.ContainerEdits{ContainerEdits: edits}
|
||||
}
|
||||
|
||||
return perDeviceEdits, nil
|
||||
}
|
||||
|
||||
// GetOpaqueDeviceConfigs returns an ordered list of the configs contained in possibleConfigs for this driver.
|
||||
//
|
||||
// Configs can either come from the resource claim itself or from the device
|
||||
// class associated with the request. Configs coming directly from the resource
|
||||
// claim take precedence over configs coming from the device class. Moreover,
|
||||
// configs found later in the list of configs attached to its source take
|
||||
// precedence over configs found earlier in the list for that source.
|
||||
//
|
||||
// All of the configs relevant to the driver from the list of possibleConfigs
|
||||
// will be returned in order of precedence (from lowest to highest). If no
|
||||
// configs are found, nil is returned.
|
||||
func GetOpaqueDeviceConfigs(
|
||||
decoder runtime.Decoder,
|
||||
driverName string,
|
||||
possibleConfigs []resourceapi.DeviceAllocationConfiguration,
|
||||
) ([]*OpaqueDeviceConfig, error) {
|
||||
// Collect all configs in order of reverse precedence.
|
||||
var classConfigs []resourceapi.DeviceAllocationConfiguration
|
||||
var claimConfigs []resourceapi.DeviceAllocationConfiguration
|
||||
var candidateConfigs []resourceapi.DeviceAllocationConfiguration
|
||||
for _, config := range possibleConfigs {
|
||||
switch config.Source {
|
||||
case resourceapi.AllocationConfigSourceClass:
|
||||
classConfigs = append(classConfigs, config)
|
||||
case resourceapi.AllocationConfigSourceClaim:
|
||||
claimConfigs = append(claimConfigs, config)
|
||||
default:
|
||||
return nil, fmt.Errorf("invalid config source: %v", config.Source)
|
||||
}
|
||||
}
|
||||
candidateConfigs = append(candidateConfigs, classConfigs...)
|
||||
candidateConfigs = append(candidateConfigs, claimConfigs...)
|
||||
|
||||
// Decode all configs that are relevant for the driver.
|
||||
var resultConfigs []*OpaqueDeviceConfig
|
||||
for _, config := range candidateConfigs {
|
||||
// If this is nil, the driver doesn't support some future API extension
|
||||
// and needs to be updated.
|
||||
if config.DeviceConfiguration.Opaque == nil {
|
||||
return nil, fmt.Errorf("only opaque parameters are supported by this driver")
|
||||
}
|
||||
|
||||
// Configs for different drivers may have been specified because a
|
||||
// single request can be satisfied by different drivers. This is not
|
||||
// an error -- drivers must skip over other driver's configs in order
|
||||
// to support this.
|
||||
if config.DeviceConfiguration.Opaque.Driver != driverName {
|
||||
continue
|
||||
}
|
||||
|
||||
decodedConfig, err := runtime.Decode(decoder, config.DeviceConfiguration.Opaque.Parameters.Raw)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("error decoding config parameters: %w", err)
|
||||
}
|
||||
|
||||
resultConfig := &OpaqueDeviceConfig{
|
||||
Requests: config.Requests,
|
||||
Config: decodedConfig,
|
||||
}
|
||||
|
||||
resultConfigs = append(resultConfigs, resultConfig)
|
||||
}
|
||||
|
||||
return resultConfigs, nil
|
||||
}
|
56
cmd/dra-multus-driver/state_test.go
Normal file
56
cmd/dra-multus-driver/state_test.go
Normal file
@@ -0,0 +1,56 @@
|
||||
/*
|
||||
* Copyright 2025 The Kubernetes Authors.
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
package main
|
||||
|
||||
import (
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
|
||||
drapbv1 "k8s.io/kubelet/pkg/apis/dra/v1beta1"
|
||||
)
|
||||
|
||||
func TestPreparedDevicesGetDevices(t *testing.T) {
|
||||
tests := map[string]struct {
|
||||
preparedDevices PreparedDevices
|
||||
expected []*drapbv1.Device
|
||||
}{
|
||||
"nil PreparedDevices": {
|
||||
preparedDevices: nil,
|
||||
expected: nil,
|
||||
},
|
||||
"several PreparedDevices": {
|
||||
preparedDevices: PreparedDevices{
|
||||
{Device: drapbv1.Device{DeviceName: "dev1"}},
|
||||
{Device: drapbv1.Device{DeviceName: "dev2"}},
|
||||
{Device: drapbv1.Device{DeviceName: "dev3"}},
|
||||
},
|
||||
expected: []*drapbv1.Device{
|
||||
{DeviceName: "dev1"},
|
||||
{DeviceName: "dev2"},
|
||||
{DeviceName: "dev3"},
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
for name, test := range tests {
|
||||
t.Run(name, func(t *testing.T) {
|
||||
devices := test.preparedDevices.GetDevices()
|
||||
assert.Equal(t, test.expected, devices)
|
||||
})
|
||||
}
|
||||
}
|
Reference in New Issue
Block a user