Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Extract NIM model manifest in all cases for model caching #88

Merged
merged 3 commits into from
Aug 21, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
158 changes: 97 additions & 61 deletions internal/controller/nimcache_controller.go
Original file line number Diff line number Diff line change
Expand Up @@ -495,73 +495,112 @@ func getSelectedProfiles(nimCache *appsv1alpha1.NIMCache) ([]string, error) {
return nil, nil
}

func (r *NIMCacheReconciler) reconcileModelSelection(ctx context.Context, nimCache *appsv1alpha1.NIMCache) (requeue bool, err error) {
func (r *NIMCacheReconciler) reconcileModelManifest(ctx context.Context, nimCache *appsv1alpha1.NIMCache) (requeue bool, err error) {
logger := r.GetLogger()

// reconcile model selection pod
if isModelSelectionRequired(nimCache) && !isModelSelectionDone(nimCache) {
// Create a temporary pod for parsing model manifest
pod := constructPodSpec(nimCache)
// Add nimCache as owner for watching on status change
if err := controllerutil.SetControllerReference(nimCache, pod, r.GetScheme()); err != nil {
return false, err
}
err := r.createPod(ctx, pod)
if err != nil {
logger.Error(err, "failed to create", "pod", pod.Name)
return false, err
}
// Model manifest is available only for NGC model pullers
if nimCache.Spec.Source.NGC == nil {
return false, nil
}

existingPod := &corev1.Pod{}
err = r.Get(ctx, client.ObjectKey{Name: pod.Name, Namespace: nimCache.Namespace}, existingPod)
if err != nil {
logger.Error(err, "failed to get pod for model selection", "pod", pod.Name)
return false, err
}
existingConfig := &corev1.ConfigMap{}
cmName := getManifestConfigName(nimCache)
err = r.Get(ctx, client.ObjectKey{Name: cmName, Namespace: nimCache.Namespace}, existingConfig)
if err != nil && client.IgnoreNotFound(err) != nil {
logger.Error(err, "failed to get configmap of the model manifest", "name", cmName)
return false, err
}

if existingPod.Status.Phase != corev1.PodRunning {
// requeue request with delay until the pod is ready
return true, nil
}
// No action if the configmap is already created
if err == nil {
return false, nil
}

// Extract manifest file
output, err := r.getPodLogs(ctx, existingPod)
if err != nil {
logger.Error(err, "failed to get pod logs for parsing model manifest file", "pod", pod.Name)
return false, err
}
// Create a configmap by extracting the model manifest
// Create a temporary pod for parsing model manifest
pod := constructPodSpec(nimCache)
// Add nimCache as owner for watching on status change
if err := controllerutil.SetControllerReference(nimCache, pod, r.GetScheme()); err != nil {
return false, err
}
err = r.createPod(ctx, pod)
if err != nil {
logger.Error(err, "failed to create", "pod", pod.Name)
return false, err
}

// Parse the file
manifest, err := nimparser.ParseModelManifestFromRawOutput([]byte(output))
if err != nil {
logger.Error(err, "Failed to parse model manifest from the pod")
return false, err
}
logger.V(2).Info("manifest file", "nimcache", nimCache.Name, "manifest", manifest)
existingPod := &corev1.Pod{}
err = r.Get(ctx, client.ObjectKey{Name: pod.Name, Namespace: nimCache.Namespace}, existingPod)
if err != nil {
logger.Error(err, "failed to get pod for model selection", "pod", pod.Name)
return false, err
}

// Create a ConfigMap with the model manifest file for re-use
err = r.createManifestConfigMap(ctx, nimCache, manifest)
if err != nil {
logger.Error(err, "Failed to create model manifest config map")
return false, err
}
if existingPod.Status.Phase != corev1.PodRunning {
// requeue request with delay until the pod is ready
return true, nil
}

// Extract manifest file
output, err := r.getPodLogs(ctx, existingPod)
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

For model manifests -- do we have plan to add NGC api (or any other web apis) to query model metadata?

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

pull an image and cat a file inside looks very hacky......

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

that is not planned in the near term. The manifest is embedded in the NIM container itself. So we need to run the container to get that (either by cp or logging)

if err != nil {
logger.Error(err, "failed to get pod logs for parsing model manifest file", "pod", pod.Name)
return false, err
}

// Parse the file
manifest, err := nimparser.ParseModelManifestFromRawOutput([]byte(output))
if err != nil {
logger.Error(err, "Failed to parse model manifest from the pod")
return false, err
}
logger.V(2).Info("manifest file", "nimcache", nimCache.Name, "manifest", manifest)

// Create a ConfigMap with the model manifest file for re-use
err = r.createManifestConfigMap(ctx, nimCache, manifest)
if err != nil {
logger.Error(err, "Failed to create model manifest config map")
return false, err
}

// Model manifest is successfully extracted, cleanup temporary pod
err = r.Delete(ctx, existingPod)
if err != nil && !errors.IsNotFound(err) {
logger.Error(err, "failed to delete", "pod", pod.Name)
// requeue request with delay until the pod is cleaned up
// this is required as NIM containers are resource heavy
return true, err
}
return false, nil
}

func (r *NIMCacheReconciler) reconcileModelSelection(ctx context.Context, nimCache *appsv1alpha1.NIMCache) error {
logger := r.GetLogger()

// reconcile model selection pod
if isModelSelectionRequired(nimCache) && !isModelSelectionDone(nimCache) {
var discoveredGPUs []string
// If no specific GPUs are provided, then auto-detect GPUs in the cluster for profile selection
if len(nimCache.Spec.Source.NGC.Model.GPUs) == 0 {
gpusByNode, err := r.GetNodeGPUProducts(ctx)
if err != nil {
logger.Error(err, "Failed to get gpus in the cluster")
return false, err
return err
}
discoveredGPUs = getUniqueGPUProducts(gpusByNode)
}

// Get the model manifest from the config
nimManifest, err := r.extractNIMManifest(ctx, getManifestConfigName(nimCache), nimCache.GetNamespace())
if err != nil {
return fmt.Errorf("failed to get model manifest config file: %w", err)
}

// Match profiles with user input
profiles, err := nimparser.MatchProfiles(nimCache.Spec.Source.NGC.Model, *manifest, discoveredGPUs)
profiles, err := nimparser.MatchProfiles(nimCache.Spec.Source.NGC.Model, *nimManifest, discoveredGPUs)
if err != nil {
logger.Error(err, "Failed to match profiles for given model parameters")
return false, err
return err
}

// Add the annotation to the NIMCache object
Expand All @@ -572,25 +611,16 @@ func (r *NIMCacheReconciler) reconcileModelSelection(ctx context.Context, nimCac
profilesJSON, err := json.Marshal(profiles)
if err != nil {
logger.Error(err, "unable to marshal profiles to JSON")
return false, err
return err
}

nimCache.Annotations[SelectedNIMProfilesAnnotationKey] = string(profilesJSON)
if err := r.Update(ctx, nimCache); err != nil {
logger.Error(err, "unable to update NIMCache with selected profiles annotation")
return false, err
}

// Selected profiles updated, cleanup temporary pod
err = r.Delete(ctx, existingPod)
if err != nil && !errors.IsNotFound(err) {
logger.Error(err, "failed to delete", "pod", pod.Name)
// requeue request with delay until the pod is cleaned up
// this is required as NIM containers are resource heavy
return true, err
return err
}
}
return false, nil
return nil
}

func (r *NIMCacheReconciler) reconcileJob(ctx context.Context, nimCache *appsv1alpha1.NIMCache) error {
Expand Down Expand Up @@ -755,10 +785,9 @@ func (r *NIMCacheReconciler) reconcileNIMCache(ctx context.Context, nimCache *ap
return ctrl.Result{}, err
}

// Reconcile NIM model selection
requeue, err := r.reconcileModelSelection(ctx, nimCache)
requeue, err := r.reconcileModelManifest(ctx, nimCache)
if err != nil {
logger.Error(err, "reconciliation of model selection failed", "pod", getPodName(nimCache))
logger.Error(err, "reconciliation to extract model manifest failed", "pod", getPodName(nimCache))
return ctrl.Result{}, err
}

Expand All @@ -767,6 +796,13 @@ func (r *NIMCacheReconciler) reconcileNIMCache(ctx context.Context, nimCache *ap
return ctrl.Result{RequeueAfter: time.Second * 30}, err
}

// Reconcile NIM model selection
err = r.reconcileModelSelection(ctx, nimCache)
if err != nil {
logger.Error(err, "reconciliation of model selection failed")
return ctrl.Result{}, err
}

// Reconcile caching Job
err = r.reconcileJob(ctx, nimCache)
if err != nil {
Expand Down
25 changes: 25 additions & 0 deletions internal/controller/nimcache_controller_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -60,6 +60,31 @@ var _ = Describe("NIMCache Controller", func() {
Client: client,
scheme: scheme,
}

nimCache := &appsv1alpha1.NIMCache{
ObjectMeta: metav1.ObjectMeta{
Name: "test-nimcache",
Namespace: "default",
},
Spec: appsv1alpha1.NIMCacheSpec{
Source: appsv1alpha1.NIMSource{NGC: &appsv1alpha1.NGCSource{ModelPuller: "nvcr.io/nim:test", PullSecret: "my-secret"}},
},
}

// Create a model manifest configmap, as we cannot run a sample NIM container to extract for tests
filePath := filepath.Join("testdata", "manifest_trtllm.yaml")
manifestData, err := nimparser.ParseModelManifest(filePath)
Expect(err).NotTo(HaveOccurred())
Expect(*manifestData).To(HaveLen(2))

err = reconciler.createManifestConfigMap(context.TODO(), nimCache, manifestData)
Expect(err).NotTo(HaveOccurred())

// Verify that the ConfigMap was created
createdConfigMap := &corev1.ConfigMap{}
err = client.Get(context.TODO(), types.NamespacedName{Name: getManifestConfigName(nimCache), Namespace: "default"}, createdConfigMap)
Expect(err).NotTo(HaveOccurred())
Expect(createdConfigMap.Data).To(HaveKey("model_manifest.yaml"))
})

AfterEach(func() {
Expand Down