Skip to content

Commit

Permalink
Fix: Enable kubeflow plugins to use dynamic log links (#6284)
Browse files Browse the repository at this point in the history
Signed-off-by: Fabio Graetz <[email protected]>
  • Loading branch information
fg91 authored Feb 27, 2025
1 parent 5fe7a73 commit 893b704
Show file tree
Hide file tree
Showing 8 changed files with 97 additions and 19 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -89,7 +89,7 @@ func GetMPIPhaseInfo(currentCondition commonOp.JobCondition, occurredAt time.Tim
}

// GetLogs will return the logs for kubeflow job
func GetLogs(pluginContext k8s.PluginContext, taskType string, objectMeta meta_v1.ObjectMeta, hasMaster bool,
func GetLogs(pluginContext k8s.PluginContext, taskType string, objectMeta meta_v1.ObjectMeta, taskTemplate *core.TaskTemplate, hasMaster bool,
workersCount int32, psReplicasCount int32, chiefReplicasCount int32, evaluatorReplicasCount int32) ([]*core.TaskLog, error) {
name := objectMeta.Name
namespace := objectMeta.Namespace
Expand Down Expand Up @@ -125,6 +125,7 @@ func GetLogs(pluginContext k8s.PluginContext, taskType string, objectMeta meta_v
PodUnixStartTime: startTime,
PodUnixFinishTime: finishTime,
TaskExecutionID: taskExecID,
TaskTemplate: taskTemplate,
},
)
if masterErr != nil {
Expand All @@ -143,6 +144,7 @@ func GetLogs(pluginContext k8s.PluginContext, taskType string, objectMeta meta_v
PodUnixStartTime: startTime,
PodUnixFinishTime: finishTime,
TaskExecutionID: taskExecID,
TaskTemplate: taskTemplate,
})
if err != nil {
return nil, err
Expand All @@ -160,6 +162,7 @@ func GetLogs(pluginContext k8s.PluginContext, taskType string, objectMeta meta_v
PodName: name + fmt.Sprintf("-psReplica-%d", psReplicaIndex),
Namespace: namespace,
TaskExecutionID: taskExecID,
TaskTemplate: taskTemplate,
})
if err != nil {
return nil, err
Expand All @@ -172,6 +175,7 @@ func GetLogs(pluginContext k8s.PluginContext, taskType string, objectMeta meta_v
PodName: name + fmt.Sprintf("-chiefReplica-%d", 0),
Namespace: namespace,
TaskExecutionID: taskExecID,
TaskTemplate: taskTemplate,
})
if err != nil {
return nil, err
Expand All @@ -184,6 +188,7 @@ func GetLogs(pluginContext k8s.PluginContext, taskType string, objectMeta meta_v
PodName: name + fmt.Sprintf("-evaluatorReplica-%d", 0),
Namespace: namespace,
TaskExecutionID: taskExecID,
TaskTemplate: taskTemplate,
})
if err != nil {
return nil, err
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@ import (

commonOp "github.com/kubeflow/common/pkg/apis/common/v1"
"github.com/stretchr/testify/assert"
"google.golang.org/protobuf/types/known/structpb"
corev1 "k8s.io/api/core/v1"
v1 "k8s.io/api/core/v1"
"k8s.io/apimachinery/pkg/api/resource"
Expand All @@ -17,6 +18,7 @@ import (
"github.com/flyteorg/flyte/flyteplugins/go/tasks/logs"
pluginsCore "github.com/flyteorg/flyte/flyteplugins/go/tasks/pluginmachinery/core"
"github.com/flyteorg/flyte/flyteplugins/go/tasks/pluginmachinery/core/mocks"
"github.com/flyteorg/flyte/flyteplugins/go/tasks/pluginmachinery/tasklog"
)

func TestMain(m *testing.M) {
Expand Down Expand Up @@ -170,12 +172,13 @@ func TestGetLogs(t *testing.T) {
workers := int32(1)
launcher := int32(1)

taskTemplate := dummyTaskTemplate()
taskCtx := dummyTaskContext()
mpiJobObjectMeta := meta_v1.ObjectMeta{
Name: "test",
Namespace: "mpi-namespace",
}
jobLogs, err := GetLogs(taskCtx, MPITaskType, mpiJobObjectMeta, false, workers, launcher, 0, 0)
jobLogs, err := GetLogs(taskCtx, MPITaskType, mpiJobObjectMeta, taskTemplate, false, workers, launcher, 0, 0)
assert.NoError(t, err)
assert.Equal(t, 1, len(jobLogs))
assert.Equal(t, fmt.Sprintf("k8s.com/#!/log/%s/%s-worker-0/pod?namespace=mpi-namespace", "mpi-namespace", "test"), jobLogs[0].GetUri())
Expand All @@ -184,7 +187,7 @@ func TestGetLogs(t *testing.T) {
Name: "test",
Namespace: "pytorch-namespace",
}
jobLogs, err = GetLogs(taskCtx, PytorchTaskType, pytorchJobObjectMeta, true, workers, launcher, 0, 0)
jobLogs, err = GetLogs(taskCtx, PytorchTaskType, pytorchJobObjectMeta, taskTemplate, true, workers, launcher, 0, 0)
assert.NoError(t, err)
assert.Equal(t, 2, len(jobLogs))
assert.Equal(t, fmt.Sprintf("k8s.com/#!/log/%s/%s-master-0/pod?namespace=pytorch-namespace", "pytorch-namespace", "test"), jobLogs[0].GetUri())
Expand All @@ -194,7 +197,7 @@ func TestGetLogs(t *testing.T) {
Name: "test",
Namespace: "tensorflow-namespace",
}
jobLogs, err = GetLogs(taskCtx, TensorflowTaskType, tensorflowJobObjectMeta, false, workers, launcher, 1, 0)
jobLogs, err = GetLogs(taskCtx, TensorflowTaskType, tensorflowJobObjectMeta, taskTemplate, false, workers, launcher, 1, 0)
assert.NoError(t, err)
assert.Equal(t, 3, len(jobLogs))
assert.Equal(t, fmt.Sprintf("k8s.com/#!/log/%s/%s-worker-0/pod?namespace=tensorflow-namespace", "tensorflow-namespace", "test"), jobLogs[0].GetUri())
Expand All @@ -209,6 +212,7 @@ func TestGetLogsTemplateUri(t *testing.T) {
StackDriverTemplateURI: "https://console.cloud.google.com/logs/query;query=resource.labels.pod_name={{.podName}}&timestamp>{{.podRFC3339StartTime}}",
}))

taskTemplate := dummyTaskTemplate()
taskCtx := dummyTaskContext()
pytorchJobObjectMeta := meta_v1.ObjectMeta{
Name: "test",
Expand All @@ -218,13 +222,44 @@ func TestGetLogsTemplateUri(t *testing.T) {
Time: time.Date(2022, time.January, 1, 12, 0, 0, 0, time.UTC),
},
}
jobLogs, err := GetLogs(taskCtx, PytorchTaskType, pytorchJobObjectMeta, true, 1, 0, 0, 0)
jobLogs, err := GetLogs(taskCtx, PytorchTaskType, pytorchJobObjectMeta, taskTemplate, true, 1, 0, 0, 0)
assert.NoError(t, err)
assert.Equal(t, 2, len(jobLogs))
assert.Equal(t, fmt.Sprintf("https://console.cloud.google.com/logs/query;query=resource.labels.pod_name=%s-master-0&timestamp>%s", "test", "2022-01-01T12:00:00Z"), jobLogs[0].GetUri())
assert.Equal(t, fmt.Sprintf("https://console.cloud.google.com/logs/query;query=resource.labels.pod_name=%s-worker-0&timestamp>%s", "test", "2022-01-01T12:00:00Z"), jobLogs[1].GetUri())
}

func TestGetLogsDynamic(t *testing.T) {
dynamicLinks := map[string]tasklog.TemplateLogPlugin{
"test-dynamic-link": {
TemplateURIs: []string{"https://some-service.com/{{.taskConfig.dynamicParam}}"},
},
}

assert.NoError(t, logs.SetLogConfig(&logs.LogConfig{
DynamicLogLinks: dynamicLinks,
}))

taskTemplate := dummyTaskTemplate()
taskTemplate.Config = map[string]string{
"link_type": "test-dynamic-link",
"dynamicParam": "dynamic-value",
}
taskCtx := dummyTaskContext()
pytorchJobObjectMeta := meta_v1.ObjectMeta{
Name: "test",
Namespace: "pytorch-" +
"namespace",
CreationTimestamp: meta_v1.Time{
Time: time.Date(2022, time.January, 1, 12, 0, 0, 0, time.UTC),
},
}
jobLogs, err := GetLogs(taskCtx, PytorchTaskType, pytorchJobObjectMeta, taskTemplate, true, 1, 0, 0, 0)
assert.NoError(t, err)
assert.Equal(t, 2, len(jobLogs))
assert.Equal(t, "https://some-service.com/dynamic-value", jobLogs[0].GetUri())
}

func dummyPodSpec() v1.PodSpec {
return v1.PodSpec{
Containers: []v1.Container{
Expand Down Expand Up @@ -297,6 +332,25 @@ func TestOverrideContainerSpecEmptyFields(t *testing.T) {
assert.Equal(t, []string{"pyflyte-execute", "--task-module", "tests.flytekit.unit.sdk.tasks.test_sidecar_tasks", "--task-name", "simple_sidecar_task", "--inputs", "{{.input}}", "--output-prefix", "{{.outputPrefix}}"}, podSpec.Containers[0].Args)
}

func dummyTaskTemplate() *core.TaskTemplate {
id := "dummy-id"

testImage := "dummy-image"

structObj := structpb.Struct{}

return &core.TaskTemplate{
Id: &core.Identifier{Name: id},
Type: "container",
Target: &core.TaskTemplate_Container{
Container: &core.Container{
Image: testImage,
},
},
Custom: &structObj,
}
}

func dummyTaskContext() pluginsCore.TaskExecutionContext {
taskCtx := &mocks.TaskExecutionContext{}

Expand Down
9 changes: 7 additions & 2 deletions flyteplugins/go/tasks/plugins/k8s/kfoperators/mpi/mpi.go
Original file line number Diff line number Diff line change
Expand Up @@ -155,17 +155,22 @@ func (mpiOperatorResourceHandler) BuildResource(ctx context.Context, taskCtx plu
// Analyzes the k8s resource and reports the status as TaskPhase. This call is expected to be relatively fast,
// any operations that might take a long time (limits are configured system-wide) should be offloaded to the
// background.
func (mpiOperatorResourceHandler) GetTaskPhase(_ context.Context, pluginContext k8s.PluginContext, resource client.Object) (pluginsCore.PhaseInfo, error) {
func (mpiOperatorResourceHandler) GetTaskPhase(ctx context.Context, pluginContext k8s.PluginContext, resource client.Object) (pluginsCore.PhaseInfo, error) {
var numWorkers, numLauncherReplicas *int32
app, ok := resource.(*kubeflowv1.MPIJob)
if !ok {
return pluginsCore.PhaseInfoUndefined, fmt.Errorf("failed to convert resource data type")
}

taskTemplate, err := pluginContext.TaskReader().Read(ctx)
if err != nil {
return pluginsCore.PhaseInfoUndefined, err
}

numWorkers = common.GetReplicaCount(app.Spec.MPIReplicaSpecs, kubeflowv1.MPIJobReplicaTypeWorker)
numLauncherReplicas = common.GetReplicaCount(app.Spec.MPIReplicaSpecs, kubeflowv1.MPIJobReplicaTypeLauncher)

taskLogs, err := common.GetLogs(pluginContext, common.MPITaskType, app.ObjectMeta, false,
taskLogs, err := common.GetLogs(pluginContext, common.MPITaskType, app.ObjectMeta, taskTemplate, false,
*numWorkers, *numLauncherReplicas, 0, 0)
if err != nil {
return pluginsCore.PhaseInfoUndefined, err
Expand Down
5 changes: 3 additions & 2 deletions flyteplugins/go/tasks/plugins/k8s/kfoperators/mpi/mpi_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -577,8 +577,9 @@ func TestGetLogs(t *testing.T) {

mpiResourceHandler := mpiOperatorResourceHandler{}
mpiJob := dummyMPIJobResource(mpiResourceHandler, workers, launcher, slots, mpiOp.JobRunning)
taskCtx := dummyMPITaskContext(dummyMPITaskTemplate("", dummyMPICustomObj(workers, launcher, slots)), resourceRequirements, nil, k8s.PluginState{})
jobLogs, err := common.GetLogs(taskCtx, common.MPITaskType, mpiJob.ObjectMeta, false, workers, launcher, 0, 0)
taskTemplate := dummyMPITaskTemplate("", dummyMPICustomObj(workers, launcher, slots))
taskCtx := dummyMPITaskContext(taskTemplate, resourceRequirements, nil, k8s.PluginState{})
jobLogs, err := common.GetLogs(taskCtx, common.MPITaskType, mpiJob.ObjectMeta, taskTemplate, false, workers, launcher, 0, 0)
assert.NoError(t, err)
assert.Equal(t, 2, len(jobLogs))
assert.Equal(t, fmt.Sprintf("k8s.com/#!/log/%s/%s-worker-0/pod?namespace=mpi-namespace", jobNamespace, jobName), jobLogs[0].GetUri())
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -204,7 +204,7 @@ func ParseElasticConfig(elasticConfig ElasticConfig) *kubeflowv1.ElasticPolicy {
// Analyses the k8s resource and reports the status as TaskPhase. This call is expected to be relatively fast,
// any operations that might take a long time (limits are configured system-wide) should be offloaded to the
// background.
func (pytorchOperatorResourceHandler) GetTaskPhase(_ context.Context, pluginContext k8s.PluginContext, resource client.Object) (pluginsCore.PhaseInfo, error) {
func (pytorchOperatorResourceHandler) GetTaskPhase(ctx context.Context, pluginContext k8s.PluginContext, resource client.Object) (pluginsCore.PhaseInfo, error) {
app, ok := resource.(*kubeflowv1.PyTorchJob)
if !ok {
return pluginsCore.PhaseInfoUndefined, fmt.Errorf("failed to convert resource data type")
Expand All @@ -218,7 +218,12 @@ func (pytorchOperatorResourceHandler) GetTaskPhase(_ context.Context, pluginCont

workersCount := common.GetReplicaCount(app.Spec.PyTorchReplicaSpecs, kubeflowv1.PyTorchJobReplicaTypeWorker)

taskLogs, err := common.GetLogs(pluginContext, common.PytorchTaskType, app.ObjectMeta, hasMaster, *workersCount, 0, 0, 0)
taskTemplate, err := pluginContext.TaskReader().Read(ctx)
if err != nil {
return pluginsCore.PhaseInfoUndefined, err
}

taskLogs, err := common.GetLogs(pluginContext, common.PytorchTaskType, app.ObjectMeta, taskTemplate, hasMaster, *workersCount, 0, 0, 0)
if err != nil {
return pluginsCore.PhaseInfoUndefined, err
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -697,8 +697,9 @@ func TestGetLogs(t *testing.T) {

pytorchResourceHandler := pytorchOperatorResourceHandler{}
pytorchJob := dummyPytorchJobResource(pytorchResourceHandler, workers, commonOp.JobRunning)
taskCtx := dummyPytorchTaskContext(dummyPytorchTaskTemplate("", dummyPytorchCustomObj(workers)), resourceRequirements, nil, "", k8s.PluginState{})
jobLogs, err := common.GetLogs(taskCtx, common.PytorchTaskType, pytorchJob.ObjectMeta, hasMaster, workers, 0, 0, 0)
taskTemplate := dummyPytorchTaskTemplate("", dummyPytorchCustomObj(workers))
taskCtx := dummyPytorchTaskContext(taskTemplate, resourceRequirements, nil, "", k8s.PluginState{})
jobLogs, err := common.GetLogs(taskCtx, common.PytorchTaskType, pytorchJob.ObjectMeta, taskTemplate, hasMaster, workers, 0, 0, 0)
assert.NoError(t, err)
assert.Equal(t, 3, len(jobLogs))
assert.Equal(t, fmt.Sprintf("k8s.com/#!/log/%s/%s-master-0/pod?namespace=pytorch-namespace", jobNamespace, jobName), jobLogs[0].GetUri())
Expand All @@ -717,8 +718,9 @@ func TestGetLogsElastic(t *testing.T) {

pytorchResourceHandler := pytorchOperatorResourceHandler{}
pytorchJob := dummyPytorchJobResource(pytorchResourceHandler, workers, commonOp.JobRunning)
taskCtx := dummyPytorchTaskContext(dummyPytorchTaskTemplate("", dummyPytorchCustomObj(workers)), resourceRequirements, nil, "", k8s.PluginState{})
jobLogs, err := common.GetLogs(taskCtx, common.PytorchTaskType, pytorchJob.ObjectMeta, hasMaster, workers, 0, 0, 0)
taskTemplate := dummyPytorchTaskTemplate("", dummyPytorchCustomObj(workers))
taskCtx := dummyPytorchTaskContext(taskTemplate, resourceRequirements, nil, "", k8s.PluginState{})
jobLogs, err := common.GetLogs(taskCtx, common.PytorchTaskType, pytorchJob.ObjectMeta, taskTemplate, hasMaster, workers, 0, 0, 0)
assert.NoError(t, err)
assert.Equal(t, 2, len(jobLogs))
assert.Equal(t, fmt.Sprintf("k8s.com/#!/log/%s/%s-worker-0/pod?namespace=pytorch-namespace", jobNamespace, jobName), jobLogs[0].GetUri())
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -151,18 +151,23 @@ func (tensorflowOperatorResourceHandler) BuildResource(ctx context.Context, task
// Analyses the k8s resource and reports the status as TaskPhase. This call is expected to be relatively fast,
// any operations that might take a long time (limits are configured system-wide) should be offloaded to the
// background.
func (tensorflowOperatorResourceHandler) GetTaskPhase(_ context.Context, pluginContext k8s.PluginContext, resource client.Object) (pluginsCore.PhaseInfo, error) {
func (tensorflowOperatorResourceHandler) GetTaskPhase(ctx context.Context, pluginContext k8s.PluginContext, resource client.Object) (pluginsCore.PhaseInfo, error) {
app, ok := resource.(*kubeflowv1.TFJob)
if !ok {
return pluginsCore.PhaseInfoUndefined, fmt.Errorf("failed to convert resource data type")
}

taskTemplate, err := pluginContext.TaskReader().Read(ctx)
if err != nil {
return pluginsCore.PhaseInfoUndefined, err
}

workersCount := common.GetReplicaCount(app.Spec.TFReplicaSpecs, kubeflowv1.TFJobReplicaTypeWorker)
psReplicasCount := common.GetReplicaCount(app.Spec.TFReplicaSpecs, kubeflowv1.TFJobReplicaTypePS)
chiefCount := common.GetReplicaCount(app.Spec.TFReplicaSpecs, kubeflowv1.TFJobReplicaTypeChief)
evaluatorReplicasCount := common.GetReplicaCount(app.Spec.TFReplicaSpecs, kubeflowv1.TFJobReplicaTypeEval)

taskLogs, err := common.GetLogs(pluginContext, common.TensorflowTaskType, app.ObjectMeta, false,
taskLogs, err := common.GetLogs(pluginContext, common.TensorflowTaskType, app.ObjectMeta, taskTemplate, false,
*workersCount, *psReplicasCount, *chiefCount, *evaluatorReplicasCount)
if err != nil {
return pluginsCore.PhaseInfoUndefined, err
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -624,8 +624,9 @@ func TestGetLogs(t *testing.T) {

tensorflowResourceHandler := tensorflowOperatorResourceHandler{}
tensorFlowJob := dummyTensorFlowJobResource(tensorflowResourceHandler, workers, psReplicas, chiefReplicas, evaluatorReplicas, commonOp.JobRunning)
taskCtx := dummyTensorFlowTaskContext(dummyTensorFlowTaskTemplate("", dummyTensorFlowCustomObj(workers, psReplicas, chiefReplicas, evaluatorReplicas)), resourceRequirements, nil, k8s.PluginState{})
jobLogs, err := common.GetLogs(taskCtx, common.TensorflowTaskType, tensorFlowJob.ObjectMeta, false,
taskTemplate := dummyTensorFlowTaskTemplate("", dummyTensorFlowCustomObj(workers, psReplicas, chiefReplicas, evaluatorReplicas))
taskCtx := dummyTensorFlowTaskContext(taskTemplate, resourceRequirements, nil, k8s.PluginState{})
jobLogs, err := common.GetLogs(taskCtx, common.TensorflowTaskType, tensorFlowJob.ObjectMeta, taskTemplate, false,
workers, psReplicas, chiefReplicas, evaluatorReplicas)
assert.NoError(t, err)
assert.Equal(t, 5, len(jobLogs))
Expand Down

0 comments on commit 893b704

Please sign in to comment.