Skip to content

Commit 31f723a

Browse files
authored
Extract NIM model manifest in all cases for model caching (#88)
* Extract NIM model manifest in all cases for model caching Signed-off-by: Shiva Krishna, Merla <[email protected]> * Update log statement Signed-off-by: Shiva Krishna, Merla <[email protected]> * Extract model manifest config only from NGC model pullers (NIM images) Signed-off-by: Shiva Krishna, Merla <[email protected]> --------- Signed-off-by: Shiva Krishna, Merla <[email protected]>
1 parent fd6ce53 commit 31f723a

File tree

2 files changed

+122
-61
lines changed

2 files changed

+122
-61
lines changed

internal/controller/nimcache_controller.go

Lines changed: 97 additions & 61 deletions
Original file line numberDiff line numberDiff line change
@@ -495,73 +495,112 @@ func getSelectedProfiles(nimCache *appsv1alpha1.NIMCache) ([]string, error) {
495495
return nil, nil
496496
}
497497

498-
func (r *NIMCacheReconciler) reconcileModelSelection(ctx context.Context, nimCache *appsv1alpha1.NIMCache) (requeue bool, err error) {
498+
func (r *NIMCacheReconciler) reconcileModelManifest(ctx context.Context, nimCache *appsv1alpha1.NIMCache) (requeue bool, err error) {
499499
logger := r.GetLogger()
500500

501-
// reconcile model selection pod
502-
if isModelSelectionRequired(nimCache) && !isModelSelectionDone(nimCache) {
503-
// Create a temporary pod for parsing model manifest
504-
pod := constructPodSpec(nimCache)
505-
// Add nimCache as owner for watching on status change
506-
if err := controllerutil.SetControllerReference(nimCache, pod, r.GetScheme()); err != nil {
507-
return false, err
508-
}
509-
err := r.createPod(ctx, pod)
510-
if err != nil {
511-
logger.Error(err, "failed to create", "pod", pod.Name)
512-
return false, err
513-
}
501+
// Model manifest is available only for NGC model pullers
502+
if nimCache.Spec.Source.NGC == nil {
503+
return false, nil
504+
}
514505

515-
existingPod := &corev1.Pod{}
516-
err = r.Get(ctx, client.ObjectKey{Name: pod.Name, Namespace: nimCache.Namespace}, existingPod)
517-
if err != nil {
518-
logger.Error(err, "failed to get pod for model selection", "pod", pod.Name)
519-
return false, err
520-
}
506+
existingConfig := &corev1.ConfigMap{}
507+
cmName := getManifestConfigName(nimCache)
508+
err = r.Get(ctx, client.ObjectKey{Name: cmName, Namespace: nimCache.Namespace}, existingConfig)
509+
if err != nil && client.IgnoreNotFound(err) != nil {
510+
logger.Error(err, "failed to get configmap of the model manifest", "name", cmName)
511+
return false, err
512+
}
521513

522-
if existingPod.Status.Phase != corev1.PodRunning {
523-
// requeue request with delay until the pod is ready
524-
return true, nil
525-
}
514+
// No action if the configmap is already created
515+
if err == nil {
516+
return false, nil
517+
}
526518

527-
// Extract manifest file
528-
output, err := r.getPodLogs(ctx, existingPod)
529-
if err != nil {
530-
logger.Error(err, "failed to get pod logs for parsing model manifest file", "pod", pod.Name)
531-
return false, err
532-
}
519+
// Create a configmap by extracting the model manifest
520+
// Create a temporary pod for parsing model manifest
521+
pod := constructPodSpec(nimCache)
522+
// Add nimCache as owner for watching on status change
523+
if err := controllerutil.SetControllerReference(nimCache, pod, r.GetScheme()); err != nil {
524+
return false, err
525+
}
526+
err = r.createPod(ctx, pod)
527+
if err != nil {
528+
logger.Error(err, "failed to create", "pod", pod.Name)
529+
return false, err
530+
}
533531

534-
// Parse the file
535-
manifest, err := nimparser.ParseModelManifestFromRawOutput([]byte(output))
536-
if err != nil {
537-
logger.Error(err, "Failed to parse model manifest from the pod")
538-
return false, err
539-
}
540-
logger.V(2).Info("manifest file", "nimcache", nimCache.Name, "manifest", manifest)
532+
existingPod := &corev1.Pod{}
533+
err = r.Get(ctx, client.ObjectKey{Name: pod.Name, Namespace: nimCache.Namespace}, existingPod)
534+
if err != nil {
535+
logger.Error(err, "failed to get pod for model selection", "pod", pod.Name)
536+
return false, err
537+
}
541538

542-
// Create a ConfigMap with the model manifest file for re-use
543-
err = r.createManifestConfigMap(ctx, nimCache, manifest)
544-
if err != nil {
545-
logger.Error(err, "Failed to create model manifest config map")
546-
return false, err
547-
}
539+
if existingPod.Status.Phase != corev1.PodRunning {
540+
// requeue request with delay until the pod is ready
541+
return true, nil
542+
}
548543

544+
// Extract manifest file
545+
output, err := r.getPodLogs(ctx, existingPod)
546+
if err != nil {
547+
logger.Error(err, "failed to get pod logs for parsing model manifest file", "pod", pod.Name)
548+
return false, err
549+
}
550+
551+
// Parse the file
552+
manifest, err := nimparser.ParseModelManifestFromRawOutput([]byte(output))
553+
if err != nil {
554+
logger.Error(err, "Failed to parse model manifest from the pod")
555+
return false, err
556+
}
557+
logger.V(2).Info("manifest file", "nimcache", nimCache.Name, "manifest", manifest)
558+
559+
// Create a ConfigMap with the model manifest file for re-use
560+
err = r.createManifestConfigMap(ctx, nimCache, manifest)
561+
if err != nil {
562+
logger.Error(err, "Failed to create model manifest config map")
563+
return false, err
564+
}
565+
566+
// Model manifest is successfully extracted, cleanup temporary pod
567+
err = r.Delete(ctx, existingPod)
568+
if err != nil && !errors.IsNotFound(err) {
569+
logger.Error(err, "failed to delete", "pod", pod.Name)
570+
// requeue request with delay until the pod is cleaned up
571+
// this is required as NIM containers are resource heavy
572+
return true, err
573+
}
574+
return false, nil
575+
}
576+
577+
func (r *NIMCacheReconciler) reconcileModelSelection(ctx context.Context, nimCache *appsv1alpha1.NIMCache) error {
578+
logger := r.GetLogger()
579+
580+
// reconcile model selection pod
581+
if isModelSelectionRequired(nimCache) && !isModelSelectionDone(nimCache) {
549582
var discoveredGPUs []string
550583
// If no specific GPUs are provided, then auto-detect GPUs in the cluster for profile selection
551584
if len(nimCache.Spec.Source.NGC.Model.GPUs) == 0 {
552585
gpusByNode, err := r.GetNodeGPUProducts(ctx)
553586
if err != nil {
554587
logger.Error(err, "Failed to get gpus in the cluster")
555-
return false, err
588+
return err
556589
}
557590
discoveredGPUs = getUniqueGPUProducts(gpusByNode)
558591
}
559592

593+
// Get the model manifest from the config
594+
nimManifest, err := r.extractNIMManifest(ctx, getManifestConfigName(nimCache), nimCache.GetNamespace())
595+
if err != nil {
596+
return fmt.Errorf("failed to get model manifest config file: %w", err)
597+
}
598+
560599
// Match profiles with user input
561-
profiles, err := nimparser.MatchProfiles(nimCache.Spec.Source.NGC.Model, *manifest, discoveredGPUs)
600+
profiles, err := nimparser.MatchProfiles(nimCache.Spec.Source.NGC.Model, *nimManifest, discoveredGPUs)
562601
if err != nil {
563602
logger.Error(err, "Failed to match profiles for given model parameters")
564-
return false, err
603+
return err
565604
}
566605

567606
// Add the annotation to the NIMCache object
@@ -572,25 +611,16 @@ func (r *NIMCacheReconciler) reconcileModelSelection(ctx context.Context, nimCac
572611
profilesJSON, err := json.Marshal(profiles)
573612
if err != nil {
574613
logger.Error(err, "unable to marshal profiles to JSON")
575-
return false, err
614+
return err
576615
}
577616

578617
nimCache.Annotations[SelectedNIMProfilesAnnotationKey] = string(profilesJSON)
579618
if err := r.Update(ctx, nimCache); err != nil {
580619
logger.Error(err, "unable to update NIMCache with selected profiles annotation")
581-
return false, err
582-
}
583-
584-
// Selected profiles updated, cleanup temporary pod
585-
err = r.Delete(ctx, existingPod)
586-
if err != nil && !errors.IsNotFound(err) {
587-
logger.Error(err, "failed to delete", "pod", pod.Name)
588-
// requeue request with delay until the pod is cleaned up
589-
// this is required as NIM containers are resource heavy
590-
return true, err
620+
return err
591621
}
592622
}
593-
return false, nil
623+
return nil
594624
}
595625

596626
func (r *NIMCacheReconciler) reconcileJob(ctx context.Context, nimCache *appsv1alpha1.NIMCache) error {
@@ -755,10 +785,9 @@ func (r *NIMCacheReconciler) reconcileNIMCache(ctx context.Context, nimCache *ap
755785
return ctrl.Result{}, err
756786
}
757787

758-
// Reconcile NIM model selection
759-
requeue, err := r.reconcileModelSelection(ctx, nimCache)
788+
requeue, err := r.reconcileModelManifest(ctx, nimCache)
760789
if err != nil {
761-
logger.Error(err, "reconciliation of model selection failed", "pod", getPodName(nimCache))
790+
logger.Error(err, "reconciliation to extract model manifest failed", "pod", getPodName(nimCache))
762791
return ctrl.Result{}, err
763792
}
764793

@@ -767,6 +796,13 @@ func (r *NIMCacheReconciler) reconcileNIMCache(ctx context.Context, nimCache *ap
767796
return ctrl.Result{RequeueAfter: time.Second * 30}, err
768797
}
769798

799+
// Reconcile NIM model selection
800+
err = r.reconcileModelSelection(ctx, nimCache)
801+
if err != nil {
802+
logger.Error(err, "reconciliation of model selection failed")
803+
return ctrl.Result{}, err
804+
}
805+
770806
// Reconcile caching Job
771807
err = r.reconcileJob(ctx, nimCache)
772808
if err != nil {

internal/controller/nimcache_controller_test.go

Lines changed: 25 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -60,6 +60,31 @@ var _ = Describe("NIMCache Controller", func() {
6060
Client: client,
6161
scheme: scheme,
6262
}
63+
64+
nimCache := &appsv1alpha1.NIMCache{
65+
ObjectMeta: metav1.ObjectMeta{
66+
Name: "test-nimcache",
67+
Namespace: "default",
68+
},
69+
Spec: appsv1alpha1.NIMCacheSpec{
70+
Source: appsv1alpha1.NIMSource{NGC: &appsv1alpha1.NGCSource{ModelPuller: "nvcr.io/nim:test", PullSecret: "my-secret"}},
71+
},
72+
}
73+
74+
// Create a model manifest configmap, as we cannot run a sample NIM container to extract for tests
75+
filePath := filepath.Join("testdata", "manifest_trtllm.yaml")
76+
manifestData, err := nimparser.ParseModelManifest(filePath)
77+
Expect(err).NotTo(HaveOccurred())
78+
Expect(*manifestData).To(HaveLen(2))
79+
80+
err = reconciler.createManifestConfigMap(context.TODO(), nimCache, manifestData)
81+
Expect(err).NotTo(HaveOccurred())
82+
83+
// Verify that the ConfigMap was created
84+
createdConfigMap := &corev1.ConfigMap{}
85+
err = client.Get(context.TODO(), types.NamespacedName{Name: getManifestConfigName(nimCache), Namespace: "default"}, createdConfigMap)
86+
Expect(err).NotTo(HaveOccurred())
87+
Expect(createdConfigMap.Data).To(HaveKey("model_manifest.yaml"))
6388
})
6489

6590
AfterEach(func() {

0 commit comments

Comments
 (0)