1818    V1beta1TrialTemplate ,
1919)
2020from  kubeflow .katib .constants  import  constants 
21+ from  kubeflow .katib .types  import  types 
2122from  kubeflow .storage_initializer .hugging_face  import  (
2223    HuggingFaceDatasetParams ,
2324    HuggingFaceModelParams ,
2425    HuggingFaceTrainerParams ,
2526)
26- from  kubernetes .client  import  V1ObjectMeta 
27+ from  kubeflow .training .models  import  KubeflowOrgV1PyTorchJob 
28+ from  kubernetes .client  import  V1Job , V1ObjectMeta 
2729
2830PVC_FAILED  =  "pvc creation failed" 
2931
@@ -476,16 +478,37 @@ def create_experiment(
476478                    learning_rate = katib .search .double (min = 1e-05 , max = 5e-05 ),
477479                ),
478480            ),
481+             "resources_per_trial" : types .TrainerResources (
482+                 num_workers = 2 ,
483+                 num_procs_per_worker = 2 ,
484+                 resources_per_worker = {"gpu" : "2" },
485+             ),
479486        },
480487        RuntimeError ,
481488    ),
482489    (
483-         "valid flow with custom objective tuning " ,
490+         "valid flow with custom objective function and Job as Trial " ,
484491        {
485492            "name" : "tune_test" ,
486493            "objective" : lambda  x : print (f"a={ x }  ),
487494            "parameters" : {"a" : katib .search .int (min = 10 , max = 100 )},
488495            "objective_metric_name" : "a" ,
496+             "resources_per_trial" : {"gpu" : "2" },
497+         },
498+         TEST_RESULT_SUCCESS ,
499+     ),
500+     (
501+         "valid flow with custom objective function and PyTorchJob as Trial" ,
502+         {
503+             "name" : "tune_test" ,
504+             "objective" : lambda  x : print (f"a={ x }  ),
505+             "parameters" : {"a" : katib .search .int (min = 10 , max = 100 )},
506+             "objective_metric_name" : "a" ,
507+             "resources_per_trial" : types .TrainerResources (
508+                 num_workers = 2 ,
509+                 num_procs_per_worker = 2 ,
510+                 resources_per_worker = {"gpu" : "2" },
511+             ),
489512        },
490513        TEST_RESULT_SUCCESS ,
491514    ),
@@ -508,6 +531,11 @@ def create_experiment(
508531                    learning_rate = katib .search .double (min = 1e-05 , max = 5e-05 ),
509532                ),
510533            ),
534+             "resources_per_trial" : types .TrainerResources (
535+                 num_workers = 2 ,
536+                 num_procs_per_worker = 2 ,
537+                 resources_per_worker = {"gpu" : "2" },
538+             ),
511539            "objective_metric_name" : "train_loss" ,
512540            "objective_type" : "minimize" ,
513541        },
@@ -597,7 +625,10 @@ def test_tune(katib_client, test_name, kwargs, expected_output):
597625                call_args  =  mock_create_experiment .call_args 
598626                experiment  =  call_args [0 ][0 ]
599627
600-                 if  test_name  ==  "valid flow with custom objective tuning" :
628+                 if  (
629+                     test_name 
630+                     ==  "valid flow with custom objective function and Job as Trial" 
631+                 ):
601632                    # Verify input_params 
602633                    args_content  =  "" .join (
603634                        experiment .spec .trial_template .trial_spec .spec .template .spec .containers [
@@ -623,6 +654,18 @@ def test_tune(katib_client, test_name, kwargs, expected_output):
623654                        objective_metric_name = "a" ,
624655                        additional_metric_names = [],
625656                    )
657+                     # Verity Trial spec 
658+                     assert  isinstance (experiment .spec .trial_template .trial_spec , V1Job )
659+ 
660+                 elif  (
661+                     test_name 
662+                     ==  "valid flow with custom objective function and PyTorchJob as Trial" 
663+                 ):
664+                     # Verity Trial spec 
665+                     assert  isinstance (
666+                         experiment .spec .trial_template .trial_spec ,
667+                         KubeflowOrgV1PyTorchJob ,
668+                     )
626669
627670                elif  test_name  ==  "valid flow with external model tuning" :
628671                    # Verify input_params 
0 commit comments