From 26bd912857ae23e5a07336bc3bd7492900cdb12e Mon Sep 17 00:00:00 2001 From: "Yuan (Bob) Gong" <4957653+Bobgy@users.noreply.github.com> Date: Thu, 18 Mar 2021 02:33:15 +0800 Subject: [PATCH] chore(v2/frontend): integrate ML Metadata tab with v2 metadata (#5308) * chore(v2/frontend): integrate ML Metadata tab with v2 metadata * fix test * test: add unit tests for getContextByTypeAndName --- frontend/src/{TestUtils.tsx => TestUtils.ts} | 16 ++++ frontend/src/lib/MlmdUtils.test.ts | 86 ++++++++++++++++++++ frontend/src/lib/MlmdUtils.ts | 29 ++++++- frontend/src/pages/RunDetails.test.tsx | 10 +-- frontend/src/pages/RunDetails.tsx | 28 ++----- 5 files changed, 140 insertions(+), 29 deletions(-) rename frontend/src/{TestUtils.tsx => TestUtils.ts} (92%) create mode 100644 frontend/src/lib/MlmdUtils.test.ts diff --git a/frontend/src/TestUtils.tsx b/frontend/src/TestUtils.ts similarity index 92% rename from frontend/src/TestUtils.tsx rename to frontend/src/TestUtils.ts index b132803bdce..c24206b892a 100644 --- a/frontend/src/TestUtils.tsx +++ b/frontend/src/TestUtils.ts @@ -27,6 +27,7 @@ import { mount, ReactWrapper } from 'enzyme'; import { object } from 'prop-types'; import { format } from 'prettier'; import snapshotDiff from 'snapshot-diff'; +import { logger } from './lib/Utils'; export default class TestUtils { /** @@ -152,3 +153,18 @@ export function diff({ export function formatHTML(html: string): string { return format(html, { parser: 'html' }); } + +export function expectWarnings() { + const loggerWarningSpy = jest.spyOn(logger, 'warn'); + loggerWarningSpy.mockImplementation(); + return () => { + expect(loggerWarningSpy).toHaveBeenCalled(); + }; +} + +export function testBestPractices() { + beforeEach(async () => { + jest.resetAllMocks(); + jest.restoreAllMocks(); + }); +} diff --git a/frontend/src/lib/MlmdUtils.test.ts b/frontend/src/lib/MlmdUtils.test.ts new file mode 100644 index 00000000000..abd58ee8aec --- /dev/null +++ b/frontend/src/lib/MlmdUtils.test.ts @@ -0,0 +1,86 @@ +import { + Api, + Context, + GetContextByTypeAndNameRequest, + GetContextByTypeAndNameResponse, +} from '@kubeflow/frontend'; +import { expectWarnings, testBestPractices } from 'src/TestUtils'; +import { Workflow, WorkflowSpec, WorkflowStatus } from 'third_party/argo-ui/argo_template'; +import { getRunContext } from './MlmdUtils'; + +testBestPractices(); + +const WORKFLOW_NAME = 'run-st448'; +const WORKFLOW_EMPTY: Workflow = { + metadata: { + name: WORKFLOW_NAME, + }, + // there are many unrelated fields here, omit them + spec: {} as WorkflowSpec, + status: {} as WorkflowStatus, +}; + +const V2_CONTEXT = new Context(); +V2_CONTEXT.setName(WORKFLOW_NAME); +V2_CONTEXT.setType('kfp.PipelineRun'); + +const TFX_CONTEXT = new Context(); +TFX_CONTEXT.setName('run.run-st448'); +TFX_CONTEXT.setType('run'); + +const V1_CONTEXT = new Context(); +V1_CONTEXT.setName(WORKFLOW_NAME); +V1_CONTEXT.setType('KfpRun'); + +describe('MlmdUtils', () => { + describe('getRunContext', () => { + it('gets KFP v2 context', async () => { + mockGetContextByTypeAndName([V2_CONTEXT]); + const context = await getRunContext({ + ...WORKFLOW_EMPTY, + metadata: { + ...WORKFLOW_EMPTY.metadata, + annotations: { 'pipelines.kubeflow.org/v2_pipeline': 'true' }, + }, + }); + expect(context).toEqual(V2_CONTEXT); + }); + + it('gets TFX context', async () => { + mockGetContextByTypeAndName([TFX_CONTEXT, V1_CONTEXT]); + const context = await getRunContext(WORKFLOW_EMPTY); + expect(context).toEqual(TFX_CONTEXT); + }); + + it('gets KFP v1 context', async () => { + const verify = expectWarnings(); + mockGetContextByTypeAndName([V1_CONTEXT]); + const context = await getRunContext(WORKFLOW_EMPTY); + expect(context).toEqual(V1_CONTEXT); + verify(); + }); + + it('throws error when not found', async () => { + const verify = expectWarnings(); + mockGetContextByTypeAndName([]); + await expect(getRunContext(WORKFLOW_EMPTY)).rejects.toThrow(); + verify(); + }); + }); +}); + +function mockGetContextByTypeAndName(contexts: Context[]) { + const getContextByTypeAndNameSpy = jest.spyOn( + Api.getInstance().metadataStoreService, + 'getContextByTypeAndName', + ); + getContextByTypeAndNameSpy.mockImplementation((req: GetContextByTypeAndNameRequest) => { + const response = new GetContextByTypeAndNameResponse(); + const found = contexts.find( + context => + context.getType() === req.getTypeName() && context.getName() === req.getContextName(), + ); + response.setContext(found); + return response; + }); +} diff --git a/frontend/src/lib/MlmdUtils.ts b/frontend/src/lib/MlmdUtils.ts index c0afafc3686..a337a7ba9c4 100644 --- a/frontend/src/lib/MlmdUtils.ts +++ b/frontend/src/lib/MlmdUtils.ts @@ -11,8 +11,16 @@ import { GetContextByTypeAndNameRequest, GetExecutionsByContextRequest, } from '@kubeflow/frontend/src/mlmd/generated/ml_metadata/proto/metadata_store_service_pb'; +import { Workflow } from 'third_party/argo-ui/argo_template'; +import { logger } from './Utils'; async function getContext({ type, name }: { type: string; name: string }): Promise { + if (type === '') { + throw new Error('Failed to getContext: type is empty.'); + } + if (name === '') { + throw new Error('Failed to getContext: name is empty.'); + } const request = new GetContextByTypeAndNameRequest(); request.setTypeName(type); request.setContextName(name); @@ -32,7 +40,7 @@ async function getContext({ type, name }: { type: string; name: string }): Promi /** * @throws error when network error, or not found */ -export async function getTfxRunContext(argoWorkflowName: string): Promise { +async function getTfxRunContext(argoWorkflowName: string): Promise { // argoPodName has the general form "pipelineName-workflowId-executionId". // All components of a pipeline within a single run will have the same // "pipelineName-workflowId" prefix. @@ -49,10 +57,27 @@ export async function getTfxRunContext(argoWorkflowName: string): Promise { +async function getKfpRunContext(argoWorkflowName: string): Promise { return await getContext({ name: argoWorkflowName, type: 'KfpRun' }); } +async function getKfpV2RunContext(argoWorkflowName: string): Promise { + return await getContext({ name: argoWorkflowName, type: 'kfp.PipelineRun' }); +} + +export async function getRunContext(workflow: Workflow): Promise { + const workflowName = workflow?.metadata?.name || ''; + if (workflow?.metadata?.annotations?.['pipelines.kubeflow.org/v2_pipeline'] === 'true') { + return await getKfpV2RunContext(workflowName); + } + try { + return await getTfxRunContext(workflowName); + } catch (err) { + logger.warn(`Cannot find tfx run context (this is expected for non tfx runs)`, err); + return await getKfpRunContext(workflowName); + } +} + /** * @throws error when network error */ diff --git a/frontend/src/pages/RunDetails.test.tsx b/frontend/src/pages/RunDetails.test.tsx index 3c6db518c6e..a7eab80384c 100644 --- a/frontend/src/pages/RunDetails.test.tsx +++ b/frontend/src/pages/RunDetails.test.tsx @@ -93,8 +93,7 @@ describe('RunDetails', () => { let terminateRunSpy: any; let artifactTypesSpy: any; let formatDateStringSpy: any; - let getTfxRunContextSpy: any; - let getKfpRunContextSpy: any; + let getRunContextSpy: any; let warnSpy: any; let testRun: ApiRunDetail = {}; @@ -158,11 +157,8 @@ describe('RunDetails', () => { // We mock this because it uses toLocaleDateString, which causes mismatches between local and CI // test environments formatDateStringSpy = jest.spyOn(Utils, 'formatDateString'); - getTfxRunContextSpy = jest.spyOn(MlmdUtils, 'getTfxRunContext').mockImplementation(() => { - throw new Error('cannot find tfx run context'); - }); - getKfpRunContextSpy = jest.spyOn(MlmdUtils, 'getKfpRunContext').mockImplementation(() => { - throw new Error('cannot find kfp run context'); + getRunContextSpy = jest.spyOn(MlmdUtils, 'getRunContext').mockImplementation(() => { + throw new Error('cannot find run context'); }); // Hide expected warning messages warnSpy = jest.spyOn(Utils.logger, 'warn').mockImplementation(); diff --git a/frontend/src/pages/RunDetails.tsx b/frontend/src/pages/RunDetails.tsx index 95a7345e8eb..7665d121b2d 100644 --- a/frontend/src/pages/RunDetails.tsx +++ b/frontend/src/pages/RunDetails.tsx @@ -22,12 +22,7 @@ import * as React from 'react'; import { Link, Redirect } from 'react-router-dom'; import { GkeMetadata, GkeMetadataContext } from 'src/lib/GkeMetadata'; import { useNamespaceChangeEvent } from 'src/lib/KubeflowClient'; -import { - ExecutionHelpers, - getExecutionsFromContext, - getKfpRunContext, - getTfxRunContext, -} from 'src/lib/MlmdUtils'; +import { ExecutionHelpers, getExecutionsFromContext, getRunContext } from 'src/lib/MlmdUtils'; import { classes, stylesheet } from 'typestyle'; import { NodePhase as ArgoNodePhase, @@ -737,20 +732,13 @@ class RunDetails extends Page { let mlmdRunContext: Context | undefined; let mlmdExecutions: Execution[] | undefined; // Get data about this workflow from MLMD - if (workflow.metadata?.name) { - try { - try { - mlmdRunContext = await getTfxRunContext(workflow.metadata.name); - } catch (err) { - logger.warn(`Cannot find tfx run context (this is expected for non tfx runs)`, err); - mlmdRunContext = await getKfpRunContext(workflow.metadata.name); - } - mlmdExecutions = await getExecutionsFromContext(mlmdRunContext); - } catch (err) { - // Data in MLMD may not exist depending on this pipeline is a TFX pipeline. - // So we only log the error in console. - logger.warn(err); - } + try { + mlmdRunContext = await getRunContext(workflow); + mlmdExecutions = await getExecutionsFromContext(mlmdRunContext); + } catch (err) { + // Data in MLMD may not exist depending on this pipeline is a TFX pipeline. + // So we only log the error in console. + logger.warn(err); } // Build runtime graph