Skip to content

Commit

Permalink
[GSoC] Add New Parameter in tune (#2369)
Browse files Browse the repository at this point in the history
* chore: add metrics_collector_config in tune function.

Signed-off-by: Electronic-Waste <[email protected]>

* rebase: rebase feat/new-param-tune to master.

Signed-off-by: Electronic-Waste <[email protected]>

* chore: add metrics collector kind list in comment.

Signed-off-by: Electronic-Waste <[email protected]>

* fix: always pass Trial name to the training container.

Signed-off-by: Electronic-Waste <[email protected]>

* fix: delete passing env variable logics in katib_client.py

Signed-off-by: Electronic-Waste <[email protected]>

* fix: passing env variable KATIB_TRIAL_NAME in the webhook of pod.

Signed-off-by: Electronic-Waste <[email protected]>

* fix: pass env variable KATIB_TRIAL_NAME only to the primary container.

Signed-off-by: Electronic-Waste <[email protected]>

* chore: add report_metrics in post_gen.py.

Signed-off-by: Electronic-Waste <[email protected]>

* fix: change nil error to allErrs(deleted by accident).

Signed-off-by: Electronic-Waste <[email protected]>

* fix: fix lint error in inject_webhook.go.

Signed-off-by: Electronic-Waste <[email protected]>

* fix: wrap env variables passing logics into mutatePodEnv.

Signed-off-by: Electronic-Waste <[email protected]>

* chore: add unit tests for mutatePodEnv.

Signed-off-by: Electronic-Waste <[email protected]>

* fix: delete protocmp.

Signed-off-by: Electronic-Waste <[email protected]>

---------

Signed-off-by: Electronic-Waste <[email protected]>
  • Loading branch information
Electronic-Waste authored Jul 18, 2024
1 parent a3dd708 commit a8840f2
Show file tree
Hide file tree
Showing 8 changed files with 156 additions and 5 deletions.
2 changes: 2 additions & 0 deletions hack/gen-python-sdk/post_gen.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,6 +41,8 @@ def _rewrite_helper(input_file, output_file, rewrite_rules):
if (output_file == "sdk/python/v1beta1/kubeflow/katib/__init__.py"):
lines.append("# Import Katib API client.\n")
lines.append("from kubeflow.katib.api.katib_client import KatibClient\n")
# Every entry must end with "\n", like the neighbouring entries do; presumably the
# list is written out verbatim, so without the newline the import statement would be
# glued onto the preceding comment line of the generated file and never executed.
lines.append("# Import Katib report metrics functions\n")
lines.append("from kubeflow.katib.api.report_metrics import report_metrics\n")
lines.append("# Import Katib helper functions.\n")
lines.append("import kubeflow.katib.api.search as search\n")
lines.append("# Import Katib helper constants.\n")
Expand Down
4 changes: 2 additions & 2 deletions pkg/apis/controller/common/v1beta1/common_types.go
Original file line number Diff line number Diff line change
Expand Up @@ -220,8 +220,8 @@ const (
CustomCollector CollectorKind = "Custom"

// When model training source code persists metrics into persistent layer
// directly, metricsCollector isn't in need, and its kind is "noneCollector"
NoneCollector CollectorKind = "None"
// directly, a sidecar container isn't needed, and its kind is "Push"
PushCollector CollectorKind = "Push"

MetricsVolume = "metrics-volume"
)
Expand Down
3 changes: 3 additions & 0 deletions pkg/controller.v1beta1/consts/const.go
Original file line number Diff line number Diff line change
Expand Up @@ -60,6 +60,9 @@ const (
// resources list which can be used as trial template
ConfigTrialResources = "trial-resources"

// EnvTrialName is the env variable of Trial name
EnvTrialName = "KATIB_TRIAL_NAME"

// LabelExperimentName is the label of experiment name.
LabelExperimentName = "katib.kubeflow.org/experiment"
// LabelSuggestionName is the label of suggestion name.
Expand Down
2 changes: 1 addition & 1 deletion pkg/webhook/v1beta1/experiment/validator/validator.go
Original file line number Diff line number Diff line change
Expand Up @@ -488,7 +488,7 @@ func (g *DefaultValidator) validateMetricsCollector(inst *experimentsv1beta1.Exp
}
// TODO(hougangliu): log warning message if some field will not be used for the metricsCollector kind
switch mcKind {
case commonapiv1beta1.NoneCollector, commonapiv1beta1.StdOutCollector:
case commonapiv1beta1.PushCollector, commonapiv1beta1.StdOutCollector:
return allErrs
case commonapiv1beta1.FileCollector:
if mcSpec.Source == nil || mcSpec.Source.FileSystemPath == nil ||
Expand Down
11 changes: 9 additions & 2 deletions pkg/webhook/v1beta1/pod/inject_webhook.go
Original file line number Diff line number Diff line change
Expand Up @@ -140,15 +140,22 @@ func (s *SidecarInjector) Mutate(pod *v1.Pod, namespace string) (*v1.Pod, error)
// Add Katib Trial labels to the Pod metadata.
mutatePodMetadata(mutatedPod, trial)

// Add env variables to the Pod's primary container.
// We add this function because of push-based metrics collection function `report_metrics` in Python SDK.
// Currently, we only pass the Trial name as env variable `KATIB_TRIAL_NAME` to the training container.
if err := mutatePodEnv(mutatedPod, trial); err != nil {
return nil, err
}

// Do the following mutation only for the Primary pod.
// If PrimaryPodLabel is not set we mutate all pods which are related to Trial job.
// Otherwise, mutate pod only with the appropriate labels.
if trial.Spec.PrimaryPodLabels != nil && !isPrimaryPod(pod.Labels, trial.Spec.PrimaryPodLabels) {
return mutatedPod, nil
}

// If Metrics Collector in None, skip the mutation.
if trial.Spec.MetricsCollector.Collector.Kind == common.NoneCollector {
// If Metrics Collector is Push, skip the mutation.
if trial.Spec.MetricsCollector.Collector.Kind == common.PushCollector {
return mutatedPod, nil
}

Expand Down
102 changes: 102 additions & 0 deletions pkg/webhook/v1beta1/pod/inject_webhook_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,8 @@ import (
"testing"
"time"

"github.com/google/go-cmp/cmp"
"github.com/google/go-cmp/cmp/cmpopts"
"github.com/onsi/gomega"
appsv1 "k8s.io/api/apps/v1"
batchv1 "k8s.io/api/batch/v1"
Expand Down Expand Up @@ -1067,3 +1069,103 @@ func TestMutatePodMetadata(t *testing.T) {
}
}
}

// TestMutatePodEnv checks that mutatePodEnv adds the Trial name env variable to the
// primary container, and that it returns an error and leaves the Pod unchanged when
// the Trial's primary container name does not match any container in the Pod.
func TestMutatePodEnv(t *testing.T) {
	// newPod builds a Pod with a single "training-container" container.
	// Called without arguments, the container's Env stays nil.
	newPod := func(env ...v1.EnvVar) *v1.Pod {
		return &v1.Pod{
			Spec: v1.PodSpec{
				Containers: []v1.Container{
					{
						Name: "training-container",
						Env:  env,
					},
				},
			},
		}
	}
	// newTrial builds a Trial that only sets the primary container name.
	newTrial := func(primaryContainerName string) *trialsv1beta1.Trial {
		return &trialsv1beta1.Trial{
			Spec: trialsv1beta1.TrialSpec{
				PrimaryContainerName: primaryContainerName,
			},
		}
	}

	cases := map[string]struct {
		pod        *v1.Pod
		trial      *trialsv1beta1.Trial
		mutatedPod *v1.Pod
		wantError  error
	}{
		"Valid case for mutating Pod's env variable": {
			pod:   newPod(),
			trial: newTrial("training-container"),
			mutatedPod: newPod(v1.EnvVar{
				Name: consts.EnvTrialName,
				ValueFrom: &v1.EnvVarSource{
					FieldRef: &v1.ObjectFieldSelector{
						FieldPath: fmt.Sprintf("metadata.labels['%s']", consts.LabelTrialName),
					},
				},
			}),
		},
		"Mismatch for Pod name and primaryContainerName in Trial": {
			pod:        newPod(),
			trial:      newTrial("training-containers"),
			mutatedPod: newPod(),
			wantError: fmt.Errorf(
				"Unable to find primary container %v in mutated pod containers %v",
				"training-containers",
				[]v1.Container{
					{
						Name: "training-container",
					},
				},
			),
		},
	}

	for desc, tc := range cases {
		t.Run(desc, func(t *testing.T) {
			gotErr := mutatePodEnv(tc.pod, tc.trial)

			// Compare the returned error with the expected one.
			switch {
			case tc.wantError != nil && gotErr != nil:
				// Both are set: the messages must be identical.
				if diff := cmp.Diff(tc.wantError.Error(), gotErr.Error()); diff != "" {
					t.Errorf("Unexpected error (-want,+got):\n%s", diff)
				}
			case tc.wantError != nil || gotErr != nil:
				// Exactly one of them is set.
				t.Errorf(
					"Unexpected error (-want,+got):\n%s",
					cmp.Diff(tc.wantError, gotErr, cmpopts.EquateErrors()),
				)
			}

			// mutatePodEnv mutates the Pod in place, so compare the input Pod
			// against the expected result.
			if diff := cmp.Diff(tc.mutatedPod, tc.pod); diff != "" {
				t.Errorf("Unexpected mutated result (-want,+got):\n%s", diff)
			}
		})
	}
}
27 changes: 27 additions & 0 deletions pkg/webhook/v1beta1/pod/utils.go
Original file line number Diff line number Diff line change
Expand Up @@ -281,6 +281,33 @@ func mutatePodMetadata(pod *v1.Pod, trial *trialsv1beta1.Trial) {
pod.Labels = podLabels
}

// mutatePodEnv appends the env variable consts.EnvTrialName to the Trial's primary
// container of the given Pod. The Pod is modified in place.
//
// The variable carries no literal value; it uses a fieldRef that points at the Pod
// label consts.LabelTrialName.
//
// An error is returned, and the Pod is left unchanged, when
// getPrimaryContainerIndex reports a negative index for
// trial.Spec.PrimaryContainerName.
func mutatePodEnv(pod *v1.Pod, trial *trialsv1beta1.Trial) error {
	// Search for the primary container.
	index := getPrimaryContainerIndex(pod.Spec.Containers, trial.Spec.PrimaryContainerName)
	if index < 0 {
		// The message text is kept verbatim: TestMutatePodEnv compares it exactly.
		return fmt.Errorf("Unable to find primary container %v in mutated pod containers %v",
			trial.Spec.PrimaryContainerName, pod.Spec.Containers)
	}

	// Pass env variable KATIB_TRIAL_NAME to the primary container using fieldPath.
	// append works on a nil slice, so Env needs no explicit initialization.
	container := &pod.Spec.Containers[index]
	container.Env = append(container.Env, v1.EnvVar{
		Name: consts.EnvTrialName,
		ValueFrom: &v1.EnvVarSource{
			FieldRef: &v1.ObjectFieldSelector{
				FieldPath: fmt.Sprintf("metadata.labels['%s']", consts.LabelTrialName),
			},
		},
	})
	return nil
}

func getSidecarContainerName(cKind common.CollectorKind) string {
if cKind == common.StdOutCollector || cKind == common.FileCollector {
return mccommon.MetricLoggerCollectorContainerName
Expand Down
10 changes: 10 additions & 0 deletions sdk/python/v1beta1/kubeflow/katib/api/katib_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -186,6 +186,7 @@ def tune(
retain_trials: bool = False,
packages_to_install: List[str] = None,
pip_index_url: str = "https://pypi.org/simple",
metrics_collector_config: Dict[str, Any] = {"kind": "StdOut"},
):
"""Create HyperParameter Tuning Katib Experiment from the objective function.
Expand Down Expand Up @@ -248,6 +249,9 @@ def tune(
to the base image packages. These packages are installed before
executing the objective function.
pip_index_url: The PyPI url from which to install Python packages.
metrics_collector_config: Specify the config of metrics collector,
for example, `metrics_collector_config = {"kind": "Push"}`.
Currently, we only support `StdOut` and `Push` metrics collector.
Raises:
ValueError: Function arguments have incorrect type or value.
Expand Down Expand Up @@ -380,6 +384,12 @@ def tune(
f"Incorrect value for env_per_trial: {env_per_trial}"
)

# Add metrics collector to the Katib Experiment.
# Currently, only the `kind` parameter is supported; it selects the metrics collector kind and defaults to `StdOut`.
experiment.spec.metrics_collector = models.V1beta1MetricsCollectorSpec(
collector=models.V1beta1CollectorSpec(kind=metrics_collector_config["kind"])
)

# Create Trial specification.
trial_spec = client.V1Job(
api_version="batch/v1",
Expand Down

0 comments on commit a8840f2

Please sign in to comment.