diff --git a/pkg/apis/controller/common/v1beta1/common_types.go b/pkg/apis/controller/common/v1beta1/common_types.go index 8722e8a474d..d2109ac18c2 100644 --- a/pkg/apis/controller/common/v1beta1/common_types.go +++ b/pkg/apis/controller/common/v1beta1/common_types.go @@ -17,7 +17,8 @@ limitations under the License. package v1beta1 import ( - v1 "k8s.io/api/core/v1" + appsv1 "k8s.io/api/apps/v1" + corev1 "k8s.io/api/core/v1" ) // AlgorithmSpec is the specification for a HP or NAS algorithm. @@ -28,6 +29,9 @@ type AlgorithmSpec struct { // Key-value pairs representing settings for suggestion algorithms. AlgorithmSettings []AlgorithmSetting `json:"algorithmSettings,omitempty"` + + // SuggestionSpec is the Deployment spec of the suggestion service. + SuggestionSpec appsv1.DeploymentSpec `json:"suggestionSpec,omitempty"` } // AlgorithmSetting represents key-value pair for HP or NAS algorithm settings. @@ -163,7 +167,7 @@ type MetricsCollectorSpec struct { type SourceSpec struct { // Model-train source code can expose metrics by http, such as HTTP endpoint in // prometheus metric format - HttpGet *v1.HTTPGetAction `json:"httpGet,omitempty"` + HttpGet *corev1.HTTPGetAction `json:"httpGet,omitempty"` // During training model, metrics may be persisted into local file in source // code, such as tfEvent use case FileSystemPath *FileSystemPath `json:"fileSystemPath,omitempty"` @@ -230,5 +234,5 @@ const ( type CollectorSpec struct { Kind CollectorKind `json:"kind,omitempty"` // When kind is "customCollector", this field will be used - CustomCollector *v1.Container `json:"customCollector,omitempty"` + CustomCollector *corev1.Container `json:"customCollector,omitempty"` } diff --git a/pkg/controller.v1beta1/suggestion/composer/composer.go b/pkg/controller.v1beta1/suggestion/composer/composer.go index 8cc508bad9d..dbddc1a4adb 100644 --- a/pkg/controller.v1beta1/suggestion/composer/composer.go +++ b/pkg/controller.v1beta1/suggestion/composer/composer.go @@ -119,8 +119,8 @@ func (g *General) DesiredDeployment(s 
*suggestionsv1beta1.Suggestion) (*appsv1.D d.Spec.Template.Spec.ServiceAccountName = suggestionConfigData.ServiceAccountName } - // Attach volume to the suggestion pod spec if ResumePolicy = FromVolume - if s.Spec.ResumePolicy == experimentsv1beta1.FromVolume { + // Attach volume to the suggestion pod spec if ResumePolicy = FromVolume or the suggestion config sets a non-empty PersistentVolumeSpec + if !equality.Semantic.DeepEqual(suggestionConfigData.PersistentVolumeSpec, corev1.PersistentVolumeSpec{}) || s.Spec.ResumePolicy == experimentsv1beta1.FromVolume { d.Spec.Template.Spec.Volumes = []corev1.Volume{ { Name: consts.ContainerSuggestionVolumeName, diff --git a/pkg/controller.v1beta1/suggestion/suggestion_controller.go b/pkg/controller.v1beta1/suggestion/suggestion_controller.go index 331d8815bdb..e5aa42e2fa1 100644 --- a/pkg/controller.v1beta1/suggestion/suggestion_controller.go +++ b/pkg/controller.v1beta1/suggestion/suggestion_controller.go @@ -18,10 +18,12 @@ package suggestion import ( "context" + "encoding/json" "fmt" appsv1 "k8s.io/api/apps/v1" corev1 "k8s.io/api/core/v1" + "k8s.io/apimachinery/pkg/api/equality" "k8s.io/apimachinery/pkg/api/errors" metav1 "k8s.io/apimachinery/pkg/apis/meta/v1" "k8s.io/apimachinery/pkg/runtime" @@ -192,13 +194,32 @@ func (r *ReconcileSuggestion) ReconcileSuggestion(instance *suggestionsv1beta1.S suggestionNsName := types.NamespacedName{Name: instance.GetName(), Namespace: instance.GetNamespace()} logger := log.WithValues("Suggestion", suggestionNsName) - suggestionConfigData, err := katibconfig.GetSuggestionConfigData(instance.Spec.Algorithm.AlgorithmName, r.Client) + // TODO(a9p): the next few blocks are from config.go::GetSuggestionConfigData, + // this should be pulled out into a utility function if correct + // Get katib config map + configMap := &corev1.ConfigMap{} + suggestionConfigData := katibconfig.SuggestionConfig{} + err := r.Get( + context.TODO(), + types.NamespacedName{Name: consts.KatibConfigMapName, Namespace: 
consts.DefaultKatibNamespace}, + configMap) if err != nil { return err } + // Try to find suggestion data in config map + config, ok := configMap.Data[consts.LabelSuggestionTag] + if ok { + // Parse suggestion data to map where key = algorithm name, value = SuggestionConfig + suggestionsConfig := map[string]katibconfig.SuggestionConfig{} + if err := json.Unmarshal([]byte(config), &suggestionsConfig); err != nil { + return err + } + // Look up SuggestionConfig for the algorithm; a missing key yields the zero-value config + suggestionConfigData = suggestionsConfig[instance.Spec.Algorithm.AlgorithmName] + } - // If ResumePolicy is FromVolume or persistentVolumeClaimSpec provided, volume is reconciled for suggestion - if suggestionConfigData.persistentVolumeClaimSpec != nil || instance.Spec.ResumePolicy == experimentsv1beta1.FromVolume { + // If ResumePolicy is FromVolume or the suggestion config sets a non-empty PersistentVolumeSpec, volume is reconciled for suggestion + if !equality.Semantic.DeepEqual(suggestionConfigData.PersistentVolumeSpec, corev1.PersistentVolumeSpec{}) || instance.Spec.ResumePolicy == experimentsv1beta1.FromVolume { pvc, pv, err := r.DesiredVolume(instance) if err != nil { return err @@ -254,8 +275,10 @@ func (r *ReconcileSuggestion) ReconcileSuggestion(instance *suggestionsv1beta1.S } else { msg := "Deployment is ready" instance.MarkSuggestionStatusDeploymentReady(corev1.ConditionTrue, SuggestionDeploymentReady, msg) + // TODO (a9p) this should be in utils, but breaks import due to it being fully-qualified + // instance.setSuggestionSpec(foundDeploy) + instance.Spec.Algorithm.SuggestionSpec = foundDeploy.Spec } - } experiment := &experimentsv1beta1.Experiment{} trials := &trialsv1beta1.TrialList{} diff --git a/pkg/controller.v1beta1/suggestion/suggestionclient/suggestionclient.go b/pkg/controller.v1beta1/suggestion/suggestionclient/suggestionclient.go index 8db6f3b82f3..11f7eac6615 100644 --- a/pkg/controller.v1beta1/suggestion/suggestionclient/suggestionclient.go +++ 
b/pkg/controller.v1beta1/suggestion/suggestionclient/suggestionclient.go @@ -109,7 +109,7 @@ func (g *General) SyncAssignments( instance.Status.AlgorithmSettings) requestSuggestion := &suggestionapi.GetSuggestionsRequest{ - Experiment: g.ConvertExperiment(filledE), + Experiment: g.ConvertExperiment(instance, filledE), Trials: g.ConvertTrials(ts), CurrentRequestNumber: int32(currentRequestNum), TotalRequestNumber: int32(instance.Spec.Requests), @@ -143,7 +143,7 @@ func (g *General) SyncAssignments( defer cancelEarlyStopping() requestEarlyStopping := &suggestionapi.GetEarlyStoppingRulesRequest{ - Experiment: g.ConvertExperiment(filledE), + Experiment: g.ConvertExperiment(instance, filledE), Trials: g.ConvertTrials(ts), DbManagerAddress: katibmanagerv1beta1.GetDBManagerAddr(), } @@ -216,7 +216,7 @@ func (g *General) ValidateAlgorithmSettings(instance *suggestionsv1beta1.Suggest defer cancel() request := &suggestionapi.ValidateAlgorithmSettingsRequest{ - Experiment: g.ConvertExperiment(e), + Experiment: g.ConvertExperiment(instance, e), } // See https://github.com/grpc/grpc-go/issues/2636 @@ -264,7 +264,7 @@ func (g *General) ValidateEarlyStoppingSettings(instance *suggestionsv1beta1.Sug defer cancel() request := &suggestionapi.ValidateEarlyStoppingSettingsRequest{ - EarlyStopping: g.ConvertExperiment(e).Spec.EarlyStopping, + EarlyStopping: g.ConvertExperiment(instance, e).Spec.EarlyStopping, } // See https://github.com/grpc/grpc-go/issues/2636 @@ -294,13 +294,14 @@ func (g *General) ValidateEarlyStoppingSettings(instance *suggestionsv1beta1.Sug } // ConvertExperiment converts CRD to the GRPC definition. 
-func (g *General) ConvertExperiment(e *experimentsv1beta1.Experiment) *suggestionapi.Experiment { +func (g *General) ConvertExperiment(s *suggestionsv1beta1.Suggestion, e *experimentsv1beta1.Experiment) *suggestionapi.Experiment { res := &suggestionapi.Experiment{} res.Name = e.Name res.Spec = &suggestionapi.ExperimentSpec{ Algorithm: &suggestionapi.AlgorithmSpec{ AlgorithmName: e.Spec.Algorithm.AlgorithmName, AlgorithmSettings: convertAlgorithmSettings(e.Spec.Algorithm.AlgorithmSettings), + SuggestionSpec: s.Spec.Algorithm.SuggestionSpec.DeepCopy(), }, Objective: &suggestionapi.ObjectiveSpec{ Type: convertObjectiveType(e.Spec.Objective.Type),