Skip to content

Commit

Permalink
Add support for multiple channels within a single ComputeDomain
Browse files Browse the repository at this point in the history
Signed-off-by: Kevin Klues <[email protected]>
  • Loading branch information
klueska committed Feb 11, 2025
1 parent 63b5999 commit fb325f3
Show file tree
Hide file tree
Showing 12 changed files with 105 additions and 78 deletions.
6 changes: 4 additions & 2 deletions api/nvidia.com/resource/v1beta1/computedomain.go
Original file line number Diff line number Diff line change
Expand Up @@ -51,11 +51,13 @@ type ComputeDomainList struct {
}

// +kubebuilder:validation:XValidation:rule="self == oldSelf", message="A computeDomain.spec is immutable"
// +kubebuilder:validation:XValidation:rule="size(self.resourceClaimTemplates) >= 1",message="The 'resourceClaimTemplates' field must have at least one entry."
// +kubebuilder:validation:XValidation:rule="size(self.resourceClaimTemplates) < 64",message="The 'resourceClaimTemplates' field must have less than 64 entries."

// ComputeDomainSpec provides the spec for a ComputeDomain.
type ComputeDomainSpec struct {
NumNodes int `json:"numNodes"`
ResourceClaimTemplate ComputeDomainResourceClaimTemplate `json:"resourceClaimTemplate"`
NumNodes int `json:"numNodes"`
ResourceClaimTemplates []ComputeDomainResourceClaimTemplate `json:"resourceClaimTemplates"`
}

// ComputeDomainResourceClaimTemplate provides the details of the ResourceClaimTemplate to generate.
Expand Down
8 changes: 6 additions & 2 deletions api/nvidia.com/resource/v1beta1/zz_generated.deepcopy.go

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

11 changes: 6 additions & 5 deletions cmd/compute-domain-controller/computedomain.go
Original file line number Diff line number Diff line change
Expand Up @@ -38,9 +38,8 @@ const (
computeDomainLabelKey = "resource.nvidia.com/computeDomain"
computeDomainFinalizer = computeDomainLabelKey

computeDomainDefaultChannelDeviceClass = "compute-domain-default-channel.nvidia.com"
computeDomainChannelDeviceClass = "compute-domain-channel.nvidia.com"
computeDomainDaemonDeviceClass = "compute-domain-daemon.nvidia.com"
computeDomainChannelDeviceClass = "compute-domain-channel.nvidia.com"
computeDomainDaemonDeviceClass = "compute-domain-daemon.nvidia.com"

computeDomainResourceClaimTemplateTargetLabelKey = "resource.nvidia.com/computeDomainTarget"
computeDomainResourceClaimTemplateTargetDaemon = "Daemon"
Expand Down Expand Up @@ -289,8 +288,10 @@ func (m *ComputeDomainManager) onAddOrUpdate(ctx context.Context, obj any) error
return fmt.Errorf("error creating DaemonSet: %w", err)
}

if _, err := m.resourceClaimTemplateManager.Create(ctx, cd.Namespace, cd.Spec.ResourceClaimTemplate.Name, cd); err != nil {
return fmt.Errorf("error creating ResourceClaimTemplate '%s/%s': %w", cd.Namespace, cd.Spec.ResourceClaimTemplate.Name, err)
for i, rct := range cd.Spec.ResourceClaimTemplates {
if _, err := m.resourceClaimTemplateManager.Create(ctx, cd.Namespace, rct.Name, i, cd); err != nil {
return fmt.Errorf("error creating ResourceClaimTemplate '%s/%s': %w", cd.Namespace, rct.Name, err)
}
}

return nil
Expand Down
78 changes: 32 additions & 46 deletions cmd/compute-domain-controller/resourceclaimtemplate.go
Original file line number Diff line number Diff line change
Expand Up @@ -53,6 +53,7 @@ type ResourceClaimTemplateTemplateData struct {
TargetLabelValue string
DeviceClassName string
DriverName string
ChannelID int
ChannelConfig *nvapi.ComputeDomainChannelConfig
DaemonConfig *nvapi.ComputeDomainDaemonConfig
}
Expand Down Expand Up @@ -166,22 +167,16 @@ func (m *BaseResourceClaimTemplateManager) Delete(ctx context.Context, cdUID str
if err != nil {
return fmt.Errorf("error retrieving ResourceClaimTemplate: %w", err)
}
if len(rcts) > 1 {
return fmt.Errorf("more than one ResourceClaimTemplate found with same ComputeDomain UID")
}
if len(rcts) == 0 {
return nil
}

rct := rcts[0]

if rct.GetDeletionTimestamp() != nil {
return nil
}
for _, rct := range rcts {
if rct.GetDeletionTimestamp() != nil {
continue
}

err = m.config.clientsets.Core.ResourceV1beta1().ResourceClaimTemplates(rct.Namespace).Delete(ctx, rct.Name, metav1.DeleteOptions{})
if err != nil && !errors.IsNotFound(err) {
return fmt.Errorf("erroring deleting ResourceClaimTemplate: %w", err)
err := m.config.clientsets.Core.ResourceV1beta1().ResourceClaimTemplates(rct.Namespace).Delete(ctx, rct.Name, metav1.DeleteOptions{})
if err != nil && !errors.IsNotFound(err) {
return fmt.Errorf("erroring deleting ResourceClaimTemplate: %w", err)
}
}

return nil
Expand All @@ -192,32 +187,26 @@ func (m *BaseResourceClaimTemplateManager) RemoveFinalizer(ctx context.Context,
if err != nil {
return fmt.Errorf("error retrieving ResourceClaimTemplate: %w", err)
}
if len(rcts) > 1 {
return fmt.Errorf("more than one ResourceClaimTemplate found with same ComputeDomain UID")
}
if len(rcts) == 0 {
return nil
}

rct := rcts[0]

if rct.GetDeletionTimestamp() == nil {
return fmt.Errorf("attempting to remove finalizer before ResourceClaimTemplate marked for deletion")
}
for _, rct := range rcts {
if rct.GetDeletionTimestamp() == nil {
return fmt.Errorf("attempting to remove finalizer before ResourceClaimTemplate marked for deletion")
}

newRCT := rct.DeepCopy()
newRCT.Finalizers = []string{}
for _, f := range rct.Finalizers {
if f != computeDomainFinalizer {
newRCT.Finalizers = append(newRCT.Finalizers, f)
newRCT := rct.DeepCopy()
newRCT.Finalizers = []string{}
for _, f := range rct.Finalizers {
if f != computeDomainFinalizer {
newRCT.Finalizers = append(newRCT.Finalizers, f)
}
}
if len(rct.Finalizers) == len(newRCT.Finalizers) {
return nil
}
}
if len(rct.Finalizers) == len(newRCT.Finalizers) {
return nil
}

if _, err = m.config.clientsets.Core.ResourceV1beta1().ResourceClaimTemplates(rct.Namespace).Update(ctx, newRCT, metav1.UpdateOptions{}); err != nil {
return fmt.Errorf("error updating ResourceClaimTemplate: %w", err)
if _, err = m.config.clientsets.Core.ResourceV1beta1().ResourceClaimTemplates(newRCT.Namespace).Update(ctx, newRCT, metav1.UpdateOptions{}); err != nil {
return fmt.Errorf("error updating ResourceClaimTemplate: %w", err)
}
}

return nil
Expand Down Expand Up @@ -274,9 +263,6 @@ func (m *DaemonSetResourceClaimTemplateManager) Create(ctx context.Context, name
daemonConfig.NumNodes = cd.Spec.NumNodes
daemonConfig.DomainID = string(cd.UID)

channelConfig := nvapi.DefaultComputeDomainChannelConfig()
channelConfig.DomainID = string(cd.UID)

templateData := ResourceClaimTemplateTemplateData{
Namespace: namespace,
GenerateName: fmt.Sprintf("%s-daemon-claim-template-", cd.Name),
Expand Down Expand Up @@ -322,16 +308,15 @@ func NewWorkloadResourceClaimTemplateManager(config *ManagerConfig) *WorkloadRes
return m
}

func (m *WorkloadResourceClaimTemplateManager) Create(ctx context.Context, namespace, name string, cd *nvapi.ComputeDomain) (*resourceapi.ResourceClaimTemplate, error) {
func (m *WorkloadResourceClaimTemplateManager) Create(ctx context.Context, namespace, name string, channel int, cd *nvapi.ComputeDomain) (*resourceapi.ResourceClaimTemplate, error) {
rcts, err := getByComputeDomainUID[*resourceapi.ResourceClaimTemplate](ctx, m.informer, string(cd.UID))
if err != nil {
return nil, fmt.Errorf("error retrieving ResourceClaimTemplate: %w", err)
}
if len(rcts) > 1 {
return nil, fmt.Errorf("more than one ResourceClaimTemplate found with same ComputeDomain UID")
}
if len(rcts) == 1 {
return rcts[0], nil
for _, rct := range rcts {
if rct.Namespace == namespace && rct.Name == name {
return rct, nil
}
}

channelConfig := nvapi.DefaultComputeDomainChannelConfig()
Expand All @@ -345,8 +330,9 @@ func (m *WorkloadResourceClaimTemplateManager) Create(ctx context.Context, names
ComputeDomainLabelValue: cd.UID,
TargetLabelKey: computeDomainResourceClaimTemplateTargetLabelKey,
TargetLabelValue: computeDomainResourceClaimTemplateTargetWorkload,
DeviceClassName: computeDomainDefaultChannelDeviceClass,
DeviceClassName: computeDomainChannelDeviceClass,
DriverName: DriverName,
ChannelID: channel,
ChannelConfig: channelConfig,
}

Expand Down
4 changes: 4 additions & 0 deletions cmd/compute-domain-kubelet-plugin/computedomain.go
Original file line number Diff line number Diff line change
Expand Up @@ -288,6 +288,10 @@ func (m *ComputeDomainManager) AddNodeLabel(ctx context.Context, cdUID string) e
}

currentValue, exists := node.Labels[computeDomainLabelKey]
if exists && currentValue != cdUID {
return fmt.Errorf("label already exists for a different ComputeDomain")
}

if exists && currentValue == cdUID {
return nil
}
Expand Down
13 changes: 8 additions & 5 deletions cmd/compute-domain-kubelet-plugin/device_state.go
Original file line number Diff line number Diff line change
Expand Up @@ -181,7 +181,7 @@ func (s *DeviceState) Unprepare(ctx context.Context, claimUID string) error {
return nil
}

if err := s.unprepareDevices(ctx, claimUID, preparedClaims[claimUID]); err != nil {
if err := s.unprepareDevices(ctx, claimUID, preparedClaims); err != nil {
return fmt.Errorf("unprepare devices failed: %w", err)
}

Expand Down Expand Up @@ -339,17 +339,20 @@ func (s *DeviceState) prepareDevices(ctx context.Context, claim *resourceapi.Res
return preparedDevices, nil
}

func (s *DeviceState) unprepareDevices(ctx context.Context, claimUID string, devices PreparedDevices) error {
func (s *DeviceState) unprepareDevices(ctx context.Context, claimUID string, preparedClaims PreparedClaims) error {
// Unprepare any ComputeDomain daemons prepared for each group of prepared devices.
for _, group := range devices {
// If a channel type, remove the ComputeDomain label from the node
for _, group := range preparedClaims[claimUID] {
// If the last channel remaining, remove the ComputeDomain label from the node
if group.ConfigState.Type == ComputeDomainChannelType {
if len(preparedClaims.ComputeDomainChannels()) > 1 {
return nil
}
if err := s.computeDomainManager.RemoveNodeLabel(ctx, group.ConfigState.ComputeDomain); err != nil {
return fmt.Errorf("error removing Node label for ComputeDomain: %w", err)
}
}

// If a daemon type, unprepare the new ComputeDomain daemon.
// If a daemon type, unprepare the new ComputeDomain daemon
if group.ConfigState.Type == ComputeDomainDaemonType {
computeDomainDaemonSettings := s.computeDomainManager.NewSettings(group.ConfigState.ComputeDomain)
if err := computeDomainDaemonSettings.Unprepare(ctx); err != nil {
Expand Down
5 changes: 0 additions & 5 deletions cmd/compute-domain-kubelet-plugin/driver.go
Original file line number Diff line number Diff line change
Expand Up @@ -68,11 +68,6 @@ func NewDriver(ctx context.Context, config *Config) (*driver, error) {
// Enumerate the set of ComputeDomain daemon devices and publish them
var resources kubeletplugin.Resources
for _, device := range state.allocatable {
// Explicitly exclude ComputeDomain channels from being advertised here. They
// are instead advertised as a network resource from the control plane.
if device.Type() == ComputeDomainChannelType && device.Channel.ID != 0 {
continue
}
resources.Devices = append(resources.Devices, device.GetDevice())
}

Expand Down
5 changes: 4 additions & 1 deletion cmd/compute-domain-kubelet-plugin/nvlib.go
Original file line number Diff line number Diff line change
Expand Up @@ -125,7 +125,10 @@ func (l deviceLib) enumerateComputeDomainDaemons(config *Config) (AllocatableDev

func (l deviceLib) getImexChannelCount() (int, error) {
// TODO: Pull this value from /proc/driver/nvidia/params
return 2048, nil
// The default is 2048.
// For now limit this to 64 (which is half the maximum number of devices
// allowed in a ResourceSlice)
return 64, nil
}

func (l deviceLib) getImexChannelMajor() (int, error) {
Expand Down
20 changes: 20 additions & 0 deletions cmd/compute-domain-kubelet-plugin/prepared.go
Original file line number Diff line number Diff line change
Expand Up @@ -94,6 +94,26 @@ func (l PreparedDeviceList) ComputeDomainDaemons() PreparedDeviceList {
return devices
}

// ComputeDomainChannels walks every prepared claim in c and, for each device
// group of that claim, appends the result of the group's
// Devices.ComputeDomainChannels() call to a single flat list, which is
// returned. The result is nil when nothing is appended.
func (c PreparedClaims) ComputeDomainChannels() PreparedDeviceList {
	var channels PreparedDeviceList
	for _, claimDevices := range c {
		for _, g := range claimDevices {
			groupChannels := g.Devices.ComputeDomainChannels()
			channels = append(channels, groupChannels...)
		}
	}
	return channels
}

// ComputeDomainDaemons walks every prepared claim in c and, for each device
// group of that claim, appends the result of the group's
// Devices.ComputeDomainDaemons() call to a single flat list, which is
// returned. The result is nil when nothing is appended.
func (c PreparedClaims) ComputeDomainDaemons() PreparedDeviceList {
	var daemons PreparedDeviceList
	for _, claimDevices := range c {
		for _, g := range claimDevices {
			groupDaemons := g.Devices.ComputeDomainDaemons()
			daemons = append(daemons, groupDaemons...)
		}
	}
	return daemons
}

func (d PreparedDevices) GetDevices() []*drapbv1.Device {
var devices []*drapbv1.Device
for _, group := range d {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -42,22 +42,28 @@ spec:
properties:
numNodes:
type: integer
resourceClaimTemplate:
description: ComputeDomainResourceClaimTemplate provides the details
of the ResourceClaimTemplate to generate.
properties:
name:
type: string
required:
- name
type: object
resourceClaimTemplates:
items:
description: ComputeDomainResourceClaimTemplate provides the details
of the ResourceClaimTemplate to generate.
properties:
name:
type: string
required:
- name
type: object
type: array
required:
- numNodes
- resourceClaimTemplate
- resourceClaimTemplates
type: object
x-kubernetes-validations:
- message: A computeDomain.spec is immutable
rule: self == oldSelf
- message: The 'resourceClaimTemplates' field must have at least one entry.
rule: size(self.resourceClaimTemplates) >= 1
- message: The 'resourceClaimTemplates' field must have less than 64 entries.
rule: size(self.resourceClaimTemplates) < 64
status:
description: ComputeDomainStatus provides the status for a ComputeDomain.
properties:
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -2,9 +2,9 @@
apiVersion: resource.k8s.io/v1beta1
kind: DeviceClass
metadata:
name: compute-domain-default-channel.nvidia.com
name: compute-domain-channel.nvidia.com
spec:
selectors:
- cel:
expression: "device.driver == 'compute-domain.nvidia.com' && device.attributes['compute-domain.nvidia.com'].type == 'channel' && device.attributes['compute-domain.nvidia.com'].id == 0"
expression: "device.driver == 'compute-domain.nvidia.com' && device.attributes['compute-domain.nvidia.com'].type == 'channel'"

3 changes: 3 additions & 0 deletions templates/compute-domain-workload-claim-template.tmpl.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,9 @@ spec:
requests:
- name: channel
deviceClassName: {{ .DeviceClassName }}
selectors:
- cel:
expression: "device.attributes['{{ .DriverName }}'].id == {{ .ChannelID }}"
config:
- requests: ["channel"]
opaque:
Expand Down

0 comments on commit fb325f3

Please sign in to comment.