From 893b704d8bee07f2b896a2248e59596ae53e0f5c Mon Sep 17 00:00:00 2001 From: "Fabio M. Graetz, Ph.D." Date: Thu, 27 Feb 2025 20:39:06 +0100 Subject: [PATCH] Fix: Enable kubeflow plugins to use dynamic log links (#6284) Signed-off-by: Fabio Graetz --- .../k8s/kfoperators/common/common_operator.go | 7 ++- .../common/common_operator_test.go | 62 +++++++++++++++++-- .../tasks/plugins/k8s/kfoperators/mpi/mpi.go | 9 ++- .../plugins/k8s/kfoperators/mpi/mpi_test.go | 5 +- .../k8s/kfoperators/pytorch/pytorch.go | 9 ++- .../k8s/kfoperators/pytorch/pytorch_test.go | 10 +-- .../k8s/kfoperators/tensorflow/tensorflow.go | 9 ++- .../kfoperators/tensorflow/tensorflow_test.go | 5 +- 8 files changed, 97 insertions(+), 19 deletions(-) diff --git a/flyteplugins/go/tasks/plugins/k8s/kfoperators/common/common_operator.go b/flyteplugins/go/tasks/plugins/k8s/kfoperators/common/common_operator.go index 18b2aa1449..53e54d081c 100644 --- a/flyteplugins/go/tasks/plugins/k8s/kfoperators/common/common_operator.go +++ b/flyteplugins/go/tasks/plugins/k8s/kfoperators/common/common_operator.go @@ -89,7 +89,7 @@ func GetMPIPhaseInfo(currentCondition commonOp.JobCondition, occurredAt time.Tim } // GetLogs will return the logs for kubeflow job -func GetLogs(pluginContext k8s.PluginContext, taskType string, objectMeta meta_v1.ObjectMeta, hasMaster bool, +func GetLogs(pluginContext k8s.PluginContext, taskType string, objectMeta meta_v1.ObjectMeta, taskTemplate *core.TaskTemplate, hasMaster bool, workersCount int32, psReplicasCount int32, chiefReplicasCount int32, evaluatorReplicasCount int32) ([]*core.TaskLog, error) { name := objectMeta.Name namespace := objectMeta.Namespace @@ -125,6 +125,7 @@ func GetLogs(pluginContext k8s.PluginContext, taskType string, objectMeta meta_v PodUnixStartTime: startTime, PodUnixFinishTime: finishTime, TaskExecutionID: taskExecID, + TaskTemplate: taskTemplate, }, ) if masterErr != nil { @@ -143,6 +144,7 @@ func GetLogs(pluginContext k8s.PluginContext, taskType string, objectMeta meta_v PodUnixStartTime: startTime, PodUnixFinishTime: finishTime, TaskExecutionID: taskExecID, + TaskTemplate: taskTemplate, }) if err != nil { return nil, err @@ -160,6 +162,7 @@ func GetLogs(pluginContext k8s.PluginContext, taskType string, objectMeta meta_v PodName: name + fmt.Sprintf("-psReplica-%d", psReplicaIndex), Namespace: namespace, TaskExecutionID: taskExecID, + TaskTemplate: taskTemplate, }) if err != nil { return nil, err @@ -172,6 +175,7 @@ func GetLogs(pluginContext k8s.PluginContext, taskType string, objectMeta meta_v PodName: name + fmt.Sprintf("-chiefReplica-%d", 0), Namespace: namespace, TaskExecutionID: taskExecID, + TaskTemplate: taskTemplate, }) if err != nil { return nil, err @@ -184,6 +188,7 @@ func GetLogs(pluginContext k8s.PluginContext, taskType string, objectMeta meta_v PodName: name + fmt.Sprintf("-evaluatorReplica-%d", 0), Namespace: namespace, TaskExecutionID: taskExecID, + TaskTemplate: taskTemplate, }) if err != nil { return nil, err diff --git a/flyteplugins/go/tasks/plugins/k8s/kfoperators/common/common_operator_test.go b/flyteplugins/go/tasks/plugins/k8s/kfoperators/common/common_operator_test.go index 73c68c4d83..0861abff62 100644 --- a/flyteplugins/go/tasks/plugins/k8s/kfoperators/common/common_operator_test.go +++ b/flyteplugins/go/tasks/plugins/k8s/kfoperators/common/common_operator_test.go @@ -8,6 +8,7 @@ import ( commonOp "github.com/kubeflow/common/pkg/apis/common/v1" "github.com/stretchr/testify/assert" + "google.golang.org/protobuf/types/known/structpb" corev1 "k8s.io/api/core/v1" v1 "k8s.io/api/core/v1" "k8s.io/apimachinery/pkg/api/resource" @@ -17,6 +18,7 @@ import ( "github.com/flyteorg/flyte/flyteplugins/go/tasks/logs" pluginsCore "github.com/flyteorg/flyte/flyteplugins/go/tasks/pluginmachinery/core" "github.com/flyteorg/flyte/flyteplugins/go/tasks/pluginmachinery/core/mocks" + "github.com/flyteorg/flyte/flyteplugins/go/tasks/pluginmachinery/tasklog" ) func TestMain(m *testing.M) { @@ -170,12 +172,13 @@ func TestGetLogs(t *testing.T) { workers := int32(1) launcher := int32(1) + taskTemplate := dummyTaskTemplate() taskCtx := dummyTaskContext() mpiJobObjectMeta := meta_v1.ObjectMeta{ Name: "test", Namespace: "mpi-namespace", } - jobLogs, err := GetLogs(taskCtx, MPITaskType, mpiJobObjectMeta, false, workers, launcher, 0, 0) + jobLogs, err := GetLogs(taskCtx, MPITaskType, mpiJobObjectMeta, taskTemplate, false, workers, launcher, 0, 0) assert.NoError(t, err) assert.Equal(t, 1, len(jobLogs)) assert.Equal(t, fmt.Sprintf("k8s.com/#!/log/%s/%s-worker-0/pod?namespace=mpi-namespace", "mpi-namespace", "test"), jobLogs[0].GetUri()) @@ -184,7 +187,7 @@ func TestGetLogs(t *testing.T) { Name: "test", Namespace: "pytorch-namespace", } - jobLogs, err = GetLogs(taskCtx, PytorchTaskType, pytorchJobObjectMeta, true, workers, launcher, 0, 0) + jobLogs, err = GetLogs(taskCtx, PytorchTaskType, pytorchJobObjectMeta, taskTemplate, true, workers, launcher, 0, 0) assert.NoError(t, err) assert.Equal(t, 2, len(jobLogs)) assert.Equal(t, fmt.Sprintf("k8s.com/#!/log/%s/%s-master-0/pod?namespace=pytorch-namespace", "pytorch-namespace", "test"), jobLogs[0].GetUri()) @@ -194,7 +197,7 @@ func TestGetLogs(t *testing.T) { Name: "test", Namespace: "tensorflow-namespace", } - jobLogs, err = GetLogs(taskCtx, TensorflowTaskType, tensorflowJobObjectMeta, false, workers, launcher, 1, 0) + jobLogs, err = GetLogs(taskCtx, TensorflowTaskType, tensorflowJobObjectMeta, taskTemplate, false, workers, launcher, 1, 0) assert.NoError(t, err) assert.Equal(t, 3, len(jobLogs)) assert.Equal(t, fmt.Sprintf("k8s.com/#!/log/%s/%s-worker-0/pod?namespace=tensorflow-namespace", "tensorflow-namespace", "test"), jobLogs[0].GetUri()) @@ -209,6 +212,7 @@ func TestGetLogsTemplateUri(t *testing.T) { StackDriverTemplateURI: "https://console.cloud.google.com/logs/query;query=resource.labels.pod_name={{.podName}}×tamp>{{.podRFC3339StartTime}}", })) + taskTemplate := dummyTaskTemplate() taskCtx := dummyTaskContext() pytorchJobObjectMeta := meta_v1.ObjectMeta{ Name: "test", @@ -218,13 +222,44 @@ func TestGetLogsTemplateUri(t *testing.T) { Time: time.Date(2022, time.January, 1, 12, 0, 0, 0, time.UTC), }, } - jobLogs, err := GetLogs(taskCtx, PytorchTaskType, pytorchJobObjectMeta, true, 1, 0, 0, 0) + jobLogs, err := GetLogs(taskCtx, PytorchTaskType, pytorchJobObjectMeta, taskTemplate, true, 1, 0, 0, 0) assert.NoError(t, err) assert.Equal(t, 2, len(jobLogs)) assert.Equal(t, fmt.Sprintf("https://console.cloud.google.com/logs/query;query=resource.labels.pod_name=%s-master-0×tamp>%s", "test", "2022-01-01T12:00:00Z"), jobLogs[0].GetUri()) assert.Equal(t, fmt.Sprintf("https://console.cloud.google.com/logs/query;query=resource.labels.pod_name=%s-worker-0×tamp>%s", "test", "2022-01-01T12:00:00Z"), jobLogs[1].GetUri()) } +func TestGetLogsDynamic(t *testing.T) { + dynamicLinks := map[string]tasklog.TemplateLogPlugin{ + "test-dynamic-link": { + TemplateURIs: []string{"https://some-service.com/{{.taskConfig.dynamicParam}}"}, + }, + } + + assert.NoError(t, logs.SetLogConfig(&logs.LogConfig{ + DynamicLogLinks: dynamicLinks, + })) + + taskTemplate := dummyTaskTemplate() + taskTemplate.Config = map[string]string{ + "link_type": "test-dynamic-link", + "dynamicParam": "dynamic-value", + } + taskCtx := dummyTaskContext() + pytorchJobObjectMeta := meta_v1.ObjectMeta{ + Name: "test", + Namespace: "pytorch-" + + "namespace", + CreationTimestamp: meta_v1.Time{ + Time: time.Date(2022, time.January, 1, 12, 0, 0, 0, time.UTC), + }, + } + jobLogs, err := GetLogs(taskCtx, PytorchTaskType, pytorchJobObjectMeta, taskTemplate, true, 1, 0, 0, 0) + assert.NoError(t, err) + assert.Equal(t, 2, len(jobLogs)) + assert.Equal(t, "https://some-service.com/dynamic-value", jobLogs[0].GetUri()) +} + func dummyPodSpec() v1.PodSpec { return v1.PodSpec{ Containers: []v1.Container{ @@ -297,6 +332,25 @@ func TestOverrideContainerSpecEmptyFields(t *testing.T) { assert.Equal(t, []string{"pyflyte-execute", "--task-module", "tests.flytekit.unit.sdk.tasks.test_sidecar_tasks", "--task-name", "simple_sidecar_task", "--inputs", "{{.input}}", "--output-prefix", "{{.outputPrefix}}"}, podSpec.Containers[0].Args) } +func dummyTaskTemplate() *core.TaskTemplate { + id := "dummy-id" + + testImage := "dummy-image" + + structObj := structpb.Struct{} + + return &core.TaskTemplate{ + Id: &core.Identifier{Name: id}, + Type: "container", + Target: &core.TaskTemplate_Container{ + Container: &core.Container{ + Image: testImage, + }, + }, + Custom: &structObj, + } +} + func dummyTaskContext() pluginsCore.TaskExecutionContext { taskCtx := &mocks.TaskExecutionContext{} diff --git a/flyteplugins/go/tasks/plugins/k8s/kfoperators/mpi/mpi.go b/flyteplugins/go/tasks/plugins/k8s/kfoperators/mpi/mpi.go index 7ba2c0cb86..68780752dc 100644 --- a/flyteplugins/go/tasks/plugins/k8s/kfoperators/mpi/mpi.go +++ b/flyteplugins/go/tasks/plugins/k8s/kfoperators/mpi/mpi.go @@ -155,17 +155,22 @@ func (mpiOperatorResourceHandler) BuildResource(ctx context.Context, taskCtx plu // Analyzes the k8s resource and reports the status as TaskPhase. This call is expected to be relatively fast, // any operations that might take a long time (limits are configured system-wide) should be offloaded to the // background. -func (mpiOperatorResourceHandler) GetTaskPhase(_ context.Context, pluginContext k8s.PluginContext, resource client.Object) (pluginsCore.PhaseInfo, error) { +func (mpiOperatorResourceHandler) GetTaskPhase(ctx context.Context, pluginContext k8s.PluginContext, resource client.Object) (pluginsCore.PhaseInfo, error) { var numWorkers, numLauncherReplicas *int32 app, ok := resource.(*kubeflowv1.MPIJob) if !ok { return pluginsCore.PhaseInfoUndefined, fmt.Errorf("failed to convert resource data type") } + taskTemplate, err := pluginContext.TaskReader().Read(ctx) + if err != nil { + return pluginsCore.PhaseInfoUndefined, err + } + numWorkers = common.GetReplicaCount(app.Spec.MPIReplicaSpecs, kubeflowv1.MPIJobReplicaTypeWorker) numLauncherReplicas = common.GetReplicaCount(app.Spec.MPIReplicaSpecs, kubeflowv1.MPIJobReplicaTypeLauncher) - taskLogs, err := common.GetLogs(pluginContext, common.MPITaskType, app.ObjectMeta, false, + taskLogs, err := common.GetLogs(pluginContext, common.MPITaskType, app.ObjectMeta, taskTemplate, false, *numWorkers, *numLauncherReplicas, 0, 0) if err != nil { return pluginsCore.PhaseInfoUndefined, err diff --git a/flyteplugins/go/tasks/plugins/k8s/kfoperators/mpi/mpi_test.go b/flyteplugins/go/tasks/plugins/k8s/kfoperators/mpi/mpi_test.go index a7c5f74366..b6fcd98973 100644 --- a/flyteplugins/go/tasks/plugins/k8s/kfoperators/mpi/mpi_test.go +++ b/flyteplugins/go/tasks/plugins/k8s/kfoperators/mpi/mpi_test.go @@ -577,8 +577,9 @@ func TestGetLogs(t *testing.T) { mpiResourceHandler := mpiOperatorResourceHandler{} mpiJob := dummyMPIJobResource(mpiResourceHandler, workers, launcher, slots, mpiOp.JobRunning) - taskCtx := dummyMPITaskContext(dummyMPITaskTemplate("", dummyMPICustomObj(workers, launcher, slots)), resourceRequirements, nil, k8s.PluginState{}) - jobLogs, err := common.GetLogs(taskCtx, common.MPITaskType, mpiJob.ObjectMeta, false, workers, launcher, 0, 0) + taskTemplate := dummyMPITaskTemplate("", dummyMPICustomObj(workers, launcher, slots)) + taskCtx := dummyMPITaskContext(taskTemplate, resourceRequirements, nil, k8s.PluginState{}) + jobLogs, err := common.GetLogs(taskCtx, common.MPITaskType, mpiJob.ObjectMeta, taskTemplate, false, workers, launcher, 0, 0) assert.NoError(t, err) assert.Equal(t, 2, len(jobLogs)) assert.Equal(t, fmt.Sprintf("k8s.com/#!/log/%s/%s-worker-0/pod?namespace=mpi-namespace", jobNamespace, jobName), jobLogs[0].GetUri()) diff --git a/flyteplugins/go/tasks/plugins/k8s/kfoperators/pytorch/pytorch.go b/flyteplugins/go/tasks/plugins/k8s/kfoperators/pytorch/pytorch.go index ce1aca9ee6..8771b773af 100644 --- a/flyteplugins/go/tasks/plugins/k8s/kfoperators/pytorch/pytorch.go +++ b/flyteplugins/go/tasks/plugins/k8s/kfoperators/pytorch/pytorch.go @@ -204,7 +204,7 @@ func ParseElasticConfig(elasticConfig ElasticConfig) *kubeflowv1.ElasticPolicy { // Analyses the k8s resource and reports the status as TaskPhase. This call is expected to be relatively fast, // any operations that might take a long time (limits are configured system-wide) should be offloaded to the // background. -func (pytorchOperatorResourceHandler) GetTaskPhase(_ context.Context, pluginContext k8s.PluginContext, resource client.Object) (pluginsCore.PhaseInfo, error) { +func (pytorchOperatorResourceHandler) GetTaskPhase(ctx context.Context, pluginContext k8s.PluginContext, resource client.Object) (pluginsCore.PhaseInfo, error) { app, ok := resource.(*kubeflowv1.PyTorchJob) if !ok { return pluginsCore.PhaseInfoUndefined, fmt.Errorf("failed to convert resource data type") @@ -218,7 +218,12 @@ func (pytorchOperatorResourceHandler) GetTaskPhase(_ context.Context, pluginCont workersCount := common.GetReplicaCount(app.Spec.PyTorchReplicaSpecs, kubeflowv1.PyTorchJobReplicaTypeWorker) - taskLogs, err := common.GetLogs(pluginContext, common.PytorchTaskType, app.ObjectMeta, hasMaster, *workersCount, 0, 0, 0) + taskTemplate, err := pluginContext.TaskReader().Read(ctx) + if err != nil { + return pluginsCore.PhaseInfoUndefined, err + } + + taskLogs, err := common.GetLogs(pluginContext, common.PytorchTaskType, app.ObjectMeta, taskTemplate, hasMaster, *workersCount, 0, 0, 0) if err != nil { return pluginsCore.PhaseInfoUndefined, err } diff --git a/flyteplugins/go/tasks/plugins/k8s/kfoperators/pytorch/pytorch_test.go b/flyteplugins/go/tasks/plugins/k8s/kfoperators/pytorch/pytorch_test.go index a662be404b..eb9003718e 100644 --- a/flyteplugins/go/tasks/plugins/k8s/kfoperators/pytorch/pytorch_test.go +++ b/flyteplugins/go/tasks/plugins/k8s/kfoperators/pytorch/pytorch_test.go @@ -697,8 +697,9 @@ func TestGetLogs(t *testing.T) { pytorchResourceHandler := pytorchOperatorResourceHandler{} pytorchJob := dummyPytorchJobResource(pytorchResourceHandler, workers, commonOp.JobRunning) - taskCtx := dummyPytorchTaskContext(dummyPytorchTaskTemplate("", dummyPytorchCustomObj(workers)), resourceRequirements, nil, "", k8s.PluginState{}) - jobLogs, err := common.GetLogs(taskCtx, common.PytorchTaskType, pytorchJob.ObjectMeta, hasMaster, workers, 0, 0, 0) + taskTemplate := dummyPytorchTaskTemplate("", dummyPytorchCustomObj(workers)) + taskCtx := dummyPytorchTaskContext(taskTemplate, resourceRequirements, nil, "", k8s.PluginState{}) + jobLogs, err := common.GetLogs(taskCtx, common.PytorchTaskType, pytorchJob.ObjectMeta, taskTemplate, hasMaster, workers, 0, 0, 0) assert.NoError(t, err) assert.Equal(t, 3, len(jobLogs)) assert.Equal(t, fmt.Sprintf("k8s.com/#!/log/%s/%s-master-0/pod?namespace=pytorch-namespace", jobNamespace, jobName), jobLogs[0].GetUri()) @@ -717,8 +718,9 @@ func TestGetLogsElastic(t *testing.T) { pytorchResourceHandler := pytorchOperatorResourceHandler{} pytorchJob := dummyPytorchJobResource(pytorchResourceHandler, workers, commonOp.JobRunning) - taskCtx := dummyPytorchTaskContext(dummyPytorchTaskTemplate("", dummyPytorchCustomObj(workers)), resourceRequirements, nil, "", k8s.PluginState{}) - jobLogs, err := common.GetLogs(taskCtx, common.PytorchTaskType, pytorchJob.ObjectMeta, hasMaster, workers, 0, 0, 0) + taskTemplate := dummyPytorchTaskTemplate("", dummyPytorchCustomObj(workers)) + taskCtx := dummyPytorchTaskContext(taskTemplate, resourceRequirements, nil, "", k8s.PluginState{}) + jobLogs, err := common.GetLogs(taskCtx, common.PytorchTaskType, pytorchJob.ObjectMeta, taskTemplate, hasMaster, workers, 0, 0, 0) assert.NoError(t, err) assert.Equal(t, 2, len(jobLogs)) assert.Equal(t, fmt.Sprintf("k8s.com/#!/log/%s/%s-worker-0/pod?namespace=pytorch-namespace", jobNamespace, jobName), jobLogs[0].GetUri()) diff --git a/flyteplugins/go/tasks/plugins/k8s/kfoperators/tensorflow/tensorflow.go b/flyteplugins/go/tasks/plugins/k8s/kfoperators/tensorflow/tensorflow.go index 3c0a3e9485..7f58a8dce1 100644 --- a/flyteplugins/go/tasks/plugins/k8s/kfoperators/tensorflow/tensorflow.go +++ b/flyteplugins/go/tasks/plugins/k8s/kfoperators/tensorflow/tensorflow.go @@ -151,18 +151,23 @@ func (tensorflowOperatorResourceHandler) BuildResource(ctx context.Context, task // Analyses the k8s resource and reports the status as TaskPhase. This call is expected to be relatively fast, // any operations that might take a long time (limits are configured system-wide) should be offloaded to the // background. -func (tensorflowOperatorResourceHandler) GetTaskPhase(_ context.Context, pluginContext k8s.PluginContext, resource client.Object) (pluginsCore.PhaseInfo, error) { +func (tensorflowOperatorResourceHandler) GetTaskPhase(ctx context.Context, pluginContext k8s.PluginContext, resource client.Object) (pluginsCore.PhaseInfo, error) { app, ok := resource.(*kubeflowv1.TFJob) if !ok { return pluginsCore.PhaseInfoUndefined, fmt.Errorf("failed to convert resource data type") } + taskTemplate, err := pluginContext.TaskReader().Read(ctx) + if err != nil { + return pluginsCore.PhaseInfoUndefined, err + } + workersCount := common.GetReplicaCount(app.Spec.TFReplicaSpecs, kubeflowv1.TFJobReplicaTypeWorker) psReplicasCount := common.GetReplicaCount(app.Spec.TFReplicaSpecs, kubeflowv1.TFJobReplicaTypePS) chiefCount := common.GetReplicaCount(app.Spec.TFReplicaSpecs, kubeflowv1.TFJobReplicaTypeChief) evaluatorReplicasCount := common.GetReplicaCount(app.Spec.TFReplicaSpecs, kubeflowv1.TFJobReplicaTypeEval) - taskLogs, err := common.GetLogs(pluginContext, common.TensorflowTaskType, app.ObjectMeta, false, + taskLogs, err := common.GetLogs(pluginContext, common.TensorflowTaskType, app.ObjectMeta, taskTemplate, false, *workersCount, *psReplicasCount, *chiefCount, *evaluatorReplicasCount) if err != nil { return pluginsCore.PhaseInfoUndefined, err diff --git a/flyteplugins/go/tasks/plugins/k8s/kfoperators/tensorflow/tensorflow_test.go b/flyteplugins/go/tasks/plugins/k8s/kfoperators/tensorflow/tensorflow_test.go index 1ced902dcb..76e3d9938e 100644 --- a/flyteplugins/go/tasks/plugins/k8s/kfoperators/tensorflow/tensorflow_test.go +++ b/flyteplugins/go/tasks/plugins/k8s/kfoperators/tensorflow/tensorflow_test.go @@ -624,8 +624,9 @@ func TestGetLogs(t *testing.T) { tensorflowResourceHandler := tensorflowOperatorResourceHandler{} tensorFlowJob := dummyTensorFlowJobResource(tensorflowResourceHandler, workers, psReplicas, chiefReplicas, evaluatorReplicas, commonOp.JobRunning) - taskCtx := dummyTensorFlowTaskContext(dummyTensorFlowTaskTemplate("", dummyTensorFlowCustomObj(workers, psReplicas, chiefReplicas, evaluatorReplicas)), resourceRequirements, nil, k8s.PluginState{}) - jobLogs, err := common.GetLogs(taskCtx, common.TensorflowTaskType, tensorFlowJob.ObjectMeta, false, + taskTemplate := dummyTensorFlowTaskTemplate("", dummyTensorFlowCustomObj(workers, psReplicas, chiefReplicas, evaluatorReplicas)) + taskCtx := dummyTensorFlowTaskContext(taskTemplate, resourceRequirements, nil, k8s.PluginState{}) + jobLogs, err := common.GetLogs(taskCtx, common.TensorflowTaskType, tensorFlowJob.ObjectMeta, taskTemplate, false, workers, psReplicas, chiefReplicas, evaluatorReplicas) assert.NoError(t, err) assert.Equal(t, 5, len(jobLogs))