diff --git a/pkg/manager/controllers/rollout/rollout_controller.go b/pkg/manager/controllers/rollout/rollout_controller.go index 93d1ccb..30adde1 100644 --- a/pkg/manager/controllers/rollout/rollout_controller.go +++ b/pkg/manager/controllers/rollout/rollout_controller.go @@ -178,7 +178,7 @@ func (r *RolloutReconciler) Reconcile(ctx context.Context, req reconcile.Request } if rollType != "all" && rollType != "" && !strings.Contains(cfg.Name, rollType) { for _, po := range shardingPods[cfg.Name] { - expectedPodRevision[po.Name] = po.Labels["controller-revision-hash"] + expectedPodRevision[po.Name] = po.Labels[appsv1.ControllerRevisionHashLabelKey] } continue } diff --git a/pkg/webhook/pod/pod_mutating_handler.go b/pkg/webhook/pod/pod_mutating_handler.go index 696d373..8c15475 100644 --- a/pkg/webhook/pod/pod_mutating_handler.go +++ b/pkg/webhook/pod/pod_mutating_handler.go @@ -19,11 +19,18 @@ package pod import ( "context" "encoding/json" + "fmt" "net/http" admissionv1 "k8s.io/api/admission/v1" + appsv1 "k8s.io/api/apps/v1" v1 "k8s.io/api/core/v1" + metav1 "k8s.io/apimachinery/pkg/apis/meta/v1" + "k8s.io/apimachinery/pkg/runtime" + "k8s.io/apimachinery/pkg/runtime/schema" + "k8s.io/apimachinery/pkg/util/strategicpatch" kubeclientset "k8s.io/client-go/kubernetes" + "k8s.io/client-go/kubernetes/scheme" "k8s.io/klog/v2" ctrl "sigs.k8s.io/controller-runtime" "sigs.k8s.io/controller-runtime/pkg/client" @@ -31,6 +38,7 @@ import ( "sigs.k8s.io/controller-runtime/pkg/webhook/admission" "github.com/KusionStack/controller-mesh/pkg/apis/ctrlmesh" + "github.com/KusionStack/controller-mesh/pkg/manager/controllers/rollout" ) type MutatingHandler struct { @@ -90,3 +98,87 @@ func (h *MutatingHandler) InjectDecoder(d *admission.Decoder) error { h.Decoder = d return nil } + +func (h *MutatingHandler) revisionRollOut(ctx context.Context, pod *v1.Pod) (err error) { + podRevision := pod.Labels[appsv1.ControllerRevisionHashLabelKey] + sts := &appsv1.StatefulSet{} + if pod.OwnerReferences == nil || len(pod.OwnerReferences) == 0 { + return fmt.Errorf("illegal owner reference") + } + if pod.OwnerReferences[0].Kind != "StatefulSet" { + return fmt.Errorf("illegal owner reference kind %s", pod.OwnerReferences[0].Kind) + } + + sts, err = h.directKubeClient.AppsV1().StatefulSets(pod.Namespace).Get(ctx, pod.OwnerReferences[0].Name, metav1.GetOptions{}) + if err != nil { + klog.Error(err) + return err + } + if sts.Spec.UpdateStrategy.Type != appsv1.OnDeleteStatefulSetStrategyType { + return nil + } + expectState := rollout.GetExpectedRevision(sts) + if expectState.UpdateRevision == "" || expectState.PodRevision == nil || expectState.PodRevision[pod.Name] == "" { + return + } + expectedRevision := expectState.PodRevision[pod.Name] + if expectedRevision == podRevision { + return + } + + expectRevision, err := h.directKubeClient.AppsV1().ControllerRevisions(pod.Namespace).Get(ctx, expectedRevision, metav1.GetOptions{}) + if err != nil { + return fmt.Errorf("cannot find old ControllerRevision %s", expectedRevision) + } + + createRevision, err := h.directKubeClient.AppsV1().ControllerRevisions(pod.Namespace).Get(ctx, podRevision, metav1.GetOptions{}) + if err != nil { + return fmt.Errorf("cannot find ControllerRevision %s by pod %s/%s", podRevision, pod.Namespace, pod.Name) + } + + expectedSts := &appsv1.StatefulSet{} + createdSts := &appsv1.StatefulSet{} + + applyPatch(expectedSts, &expectRevision.Data.Raw) + applyPatch(createdSts, &createRevision.Data.Raw) + + expectedPo := &v1.Pod{ + Spec: expectedSts.Spec.Template.Spec, + } + createdPo := &v1.Pod{ + Spec: createdSts.Spec.Template.Spec, + } + + expectedBt, _ := runtime.Encode(patchCodec, expectedPo) + createdBt, _ := runtime.Encode(patchCodec, createdPo) + currentBt, _ := runtime.Encode(patchCodec, pod) + + patch, err := strategicpatch.CreateTwoWayMergePatch(createdBt, expectedBt, expectedPo) + if err != nil { + return err + } + originBt, err := strategicpatch.StrategicMergePatch(currentBt, patch, pod) + if err != nil { + return err + } + newPod := &v1.Pod{} + if err = json.Unmarshal(originBt, newPod); err != nil { + return err + } + pod.Spec = newPod.Spec + pod.Labels[appsv1.ControllerRevisionHashLabelKey] = expectedRevision + return +} + +var patchCodec = scheme.Codecs.LegacyCodec(schema.GroupVersion{Group: "apps", Version: "v1"}, schema.GroupVersion{Version: "v1"}) + +func applyPatch(target runtime.Object, podPatch *[]byte) error { + patched, err := strategicpatch.StrategicMergePatch([]byte(runtime.EncodeOrDie(patchCodec, target)), *podPatch, target) + if err != nil { + return err + } + if err = json.Unmarshal(patched, target); err != nil { + return err + } + return nil +}